Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
})
}

private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
case dt if dt == from => identity[Any]
case StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case CalendarIntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
case IntegerType => castToInt(from)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
case udt: UserDefinedType[_]
if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
identity[Any]
case _: UserDefinedType[_] =>
throw new SparkException(s"Cannot cast $from to $to.")
private[this] def cast(from: DataType, to: DataType): Any => Any = {
// If the cast does not change the structure, then we don't really need to cast anything.
// We can return what the children return. Same thing should happen in the codegen path.
if (DataType.equalsStructurally(from, to)) {
identity
} else {
to match {
case dt if dt == from => identity[Any]
case StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case CalendarIntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
case IntegerType => castToInt(from)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
case array: ArrayType =>
castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
case udt: UserDefinedType[_]
if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
identity[Any]
case _: UserDefinedType[_] =>
throw new SparkException(s"Cannot cast $from to $to.")
}
}
}

private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)

protected override def nullSafeEval(input: Any): Any = cast(input)

override def genCode(ctx: CodegenContext): ExprCode = {
// If the cast does not change the structure, then we don't really need to cast anything.
// We can return what the children return. Same thing should happen in the interpreted path.
if (DataType.equalsStructurally(child.dataType, dataType)) {
child.genCode(ctx)
} else {
super.genCode(ctx)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,30 @@ object DataType {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

/**
* Returns true if the two data types share the same "shape", i.e. the types (including
* nullability) are the same, but the field names don't need to be the same.
*/
def equalsStructurally(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (left: ArrayType, right: ArrayType) =>
equalsStructurally(left.elementType, right.elementType) &&
left.containsNull == right.containsNull
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we be more flexible here? i.e. !left.containsNull || right.containsNull

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not symmetric. equalsStructurally should be symmetric, unless we rename this something else (e.g. structurallyCastable)


case (left: MapType, right: MapType) =>
equalsStructurally(left.keyType, right.keyType) &&
equalsStructurally(left.valueType, right.valueType) &&
left.valueContainsNull == right.valueContainsNull

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields)
.forall { case (l, r) =>
equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
}

test("SPARK-20302 cast with same structure") {
val from = new StructType()
.add("a", IntegerType)
.add("b", new StructType().add("b1", LongType))

val to = new StructType()
.add("a1", IntegerType)
.add("b1", new StructType().add("b11", LongType))

val input = Row(10, Row(12L))

checkEvaluation(cast(Literal.create(input, from), to), input)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
checkCatalogString(ArrayType(createStruct(40)))
checkCatalogString(MapType(IntegerType, StringType))
checkCatalogString(MapType(IntegerType, createStruct(40)))

def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
val testName = s"equalsStructurally: (from: $from, to: $to)"
test(testName) {
assert(DataType.equalsStructurally(from, to) === expected)
}
}

checkEqualsStructurally(BooleanType, BooleanType, true)
checkEqualsStructurally(IntegerType, IntegerType, true)
checkEqualsStructurally(IntegerType, LongType, false)
checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)

checkEqualsStructurally(
new StructType().add("f1", IntegerType),
new StructType().add("f2", IntegerType),
true)
checkEqualsStructurally(
new StructType().add("f1", IntegerType),
new StructType().add("f2", IntegerType, false),
false)

checkEqualsStructurally(
new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
true)
checkEqualsStructurally(
new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
false)
}