sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

@@ -28,6 +28,20 @@ object HiveTypeCoercion {
   val numericPrecedence =
     Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
   val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+
+  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+    val valueTypes = Seq(t1, t2).filter(t => t != NullType)
+    if (valueTypes.distinct.size > 1) {
+      // Try and find a promotion rule that contains both types in question.
+      val applicableConversion =
+        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+      // If found return the widest common type, otherwise None
+      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+    } else {
+      Some(if (valueTypes.size == 0) NullType else valueTypes.head)
+    }
+  }
 }

 /**
@@ -51,22 +65,6 @@ trait HiveTypeCoercion {
     Division ::
     Nil

-  trait TypeWidening {
-    def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
-      val valueTypes = Seq(t1, t2).filter(t => t != NullType)
-      if (valueTypes.distinct.size > 1) {
-        // Try and find a promotion rule that contains both types in question.
-        val applicableConversion =
-          HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
-        // If found return the widest common type, otherwise None
-        applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-      } else {
-        Some(if (valueTypes.size == 0) NullType else valueTypes.head)
-      }
-    }
-  }
-
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -147,7 +145,8 @@ trait HiveTypeCoercion {
    * - LongType to FloatType
    * - LongType to DoubleType
    */
-  object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
+  object WidenTypes extends Rule[LogicalPlan] {
+    import HiveTypeCoercion._

     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -343,7 +342,9 @@ trait HiveTypeCoercion {
   /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
+  object CaseWhenCoercion extends Rule[LogicalPlan] {
+    import HiveTypeCoercion._
+
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
         val valueTypes = branches.sliding(2, 2).map {
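Note: the behavior of the relocated helper follows directly from the code above: identical types (including a pair of NullTypes) come back as-is, NullType paired with a value type defers to that type, numeric pairs resolve to whichever of the two sits later in numericPrecedence, and pairs that share no promotion chain yield None. A minimal REPL-style sketch restating this, not part of the diff (it assumes the org.apache.spark.sql.catalyst.analysis package for HiveTypeCoercion, which the hunks above do not show):

    import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
    import org.apache.spark.sql.catalyst.types._

    // Numeric pairs widen along numericPrecedence.
    assert(HiveTypeCoercion.findTightestCommonType(IntegerType, DoubleType) == Some(DoubleType))
    // NullType is filtered out, so the non-null type wins.
    assert(HiveTypeCoercion.findTightestCommonType(NullType, LongType) == Some(LongType))
    // Equal types short-circuit to themselves, so BooleanType pairs need no special case downstream.
    assert(HiveTypeCoercion.findTightestCommonType(BooleanType, BooleanType) == Some(BooleanType))
    // Types that never share a promotion chain have no tightest common type.
    assert(HiveTypeCoercion.findTightestCommonType(StringType, IntegerType) == None)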
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

@@ -23,16 +23,13 @@ import org.apache.spark.sql.catalyst.types._

 class HiveTypeCoercionSuite extends FunSuite {

-  val rules = new HiveTypeCoercion { }
-  import rules._
-
   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
-      var found = WidenTypes.findTightestCommonType(t1, t2)
+      var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
       // Test both directions to make sure the widening is symmetric.
-      found = WidenTypes.findTightestCommonType(t2, t1)
+      found = HiveTypeCoercion.findTightestCommonType(t2, t1)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
     }
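Note: because widenTest checks both argument orders, a short list of calls covers the interesting cases. The hunk does not show the suite's actual case list, so the calls below are only illustrative of what the helper implies; they would sit inside the test("tightest common bound for types") block shown above:

    widenTest(ByteType, IntegerType, Some(IntegerType))   // numeric widening along numericPrecedence
    widenTest(NullType, StringType, Some(StringType))     // NullType defers to the other type
    widenTest(NullType, NullType, Some(NullType))         // both NullType => NullType
    widenTest(StringType, DoubleType, None)               // no promotion chain contains both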
52 changes: 22 additions & 30 deletions sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -125,39 +125,31 @@ private[sql] object JsonRDD extends Logging {
    * Returns the most general data type for two given data types.
    */
   private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
-    // Try and find a promotion rule that contains both types in question.
-    val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
-      .contains(t2))
-
-    // If found return the widest common type, otherwise None
-    val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-
-    if (returnType.isDefined) {
-      returnType.get
-    } else {
-      // t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
-      (t1, t2) match {
-        case (other: DataType, NullType) => other
-        case (NullType, other: DataType) => other
-        case (StructType(fields1), StructType(fields2)) => {
-          val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
-            case (name, fieldTypes) => {
-              val dataType = fieldTypes.map(field => field.dataType).reduce(
-                (type1: DataType, type2: DataType) => compatibleType(type1, type2))
-              StructField(name, dataType, true)
+    HiveTypeCoercion.findTightestCommonType(t1,t2) match {
+      case Some(commonType) => commonType
+      case None =>
+        // t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
+        (t1, t2) match {
+          case (other: DataType, NullType) => other
+          case (NullType, other: DataType) => other
+          case (StructType(fields1), StructType(fields2)) => {
+            val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
+              case (name, fieldTypes) => {
+                val dataType = fieldTypes.map(field => field.dataType).reduce(
+                  (type1: DataType, type2: DataType) => compatibleType(type1, type2))
+                StructField(name, dataType, true)
+              }
             }
+            StructType(newFields.toSeq.sortBy {
+              case StructField(name, _, _) => name
+            })
           }
-          StructType(newFields.toSeq.sortBy {
-            case StructField(name, _, _) => name
-          })
+          case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+            ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+          // TODO: We should use JsonObjectStringType to mark that values of field will be
+          // strings and every string is a Json object.
+          case (_, _) => StringType
         }
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-        // TODO: We should use JsonObjectStringType to mark that values of field will be
-        // strings and every string is a Json object.
-        case (BooleanType, BooleanType) => BooleanType
-        case (_, _) => StringType
-      }
     }
   }

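Note: the shape of the JsonRDD change mirrors the first file. compatibleType now asks findTightestCommonType first and only falls back to the structural match for StructType, ArrayType, and the catch-all StringType case; the explicit (BooleanType, BooleanType) arm disappears because equal types are already resolved by the helper. Since compatibleType is private[json], the expectations below are written as a hypothetical test inside that package (the suite name and cases are illustrative, not part of the diff; they restate what the code above implies):

    package org.apache.spark.sql.json

    import org.scalatest.FunSuite
    import org.apache.spark.sql.catalyst.types._

    class CompatibleTypeSketchSuite extends FunSuite {
      test("compatibleType delegates to findTightestCommonType first") {
        assert(JsonRDD.compatibleType(IntegerType, LongType) == LongType)        // numeric widening via the helper
        assert(JsonRDD.compatibleType(BooleanType, BooleanType) == BooleanType)  // equal types, covered by the helper
        assert(JsonRDD.compatibleType(BooleanType, LongType) == StringType)      // no common type => catch-all StringType
      }
    }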