Skip to content

Commit 3d2134f

Browse files
yjshen authored and rxin committed
[SPARK-9055][SQL] WidenTypes should also support Intersect and Except
JIRA: https://issues.apache.org/jira/browse/SPARK-9055

cc rxin

Author: Yijie Shen <[email protected]>

Closes apache#7491 from yijieshen/widen and squashes the following commits:

079fa52 [Yijie Shen] widenType support for intersect and except
1 parent cdc36ee commit 3d2134f

File tree

3 files changed

+94
-41
lines changed

3 files changed

+94
-41
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 53 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
2020
import javax.annotation.Nullable
2121

2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
23+
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.types._
2626

@@ -168,52 +168,65 @@ object HiveTypeCoercion {
168168
* - LongType to DoubleType
169169
*/
170170
object WidenTypes extends Rule[LogicalPlan] {
171-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
172-
// TODO: unions with fixed-precision decimals
173-
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
174-
val castedInput = left.output.zip(right.output).map {
175-
// When a string is found on one side, make the other side a string too.
176-
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
177-
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
178-
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
179-
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)
180171

181-
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
182-
logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}")
183-
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
184-
val newLeft =
185-
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
186-
val newRight =
187-
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
188-
189-
(newLeft, newRight)
190-
}.getOrElse {
191-
// If there is no applicable conversion, leave expression unchanged.
192-
(lhs, rhs)
193-
}
172+
private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
173+
(LogicalPlan, LogicalPlan) = {
174+
175+
// TODO: with fixed-precision decimals
176+
val castedInput = left.output.zip(right.output).map {
177+
// When a string is found on one side, make the other side a string too.
178+
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
179+
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
180+
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
181+
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)
182+
183+
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
184+
logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
185+
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
186+
val newLeft =
187+
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
188+
val newRight =
189+
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
190+
191+
(newLeft, newRight)
192+
}.getOrElse {
193+
// If there is no applicable conversion, leave expression unchanged.
194+
(lhs, rhs)
195+
}
194196

195-
case other => other
196-
}
197+
case other => other
198+
}
197199

198-
val (castedLeft, castedRight) = castedInput.unzip
200+
val (castedLeft, castedRight) = castedInput.unzip
199201

200-
val newLeft =
201-
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
202-
logDebug(s"Widening numeric types in union $castedLeft ${left.output}")
203-
Project(castedLeft, left)
204-
} else {
205-
left
206-
}
202+
val newLeft =
203+
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
204+
logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
205+
Project(castedLeft, left)
206+
} else {
207+
left
208+
}
207209

208-
val newRight =
209-
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
210-
logDebug(s"Widening numeric types in union $castedRight ${right.output}")
211-
Project(castedRight, right)
212-
} else {
213-
right
214-
}
210+
val newRight =
211+
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
212+
logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
213+
Project(castedRight, right)
214+
} else {
215+
right
216+
}
217+
(newLeft, newRight)
218+
}
215219

220+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
221+
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
222+
val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
216223
Union(newLeft, newRight)
224+
case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
225+
val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
226+
Except(newLeft, newRight)
227+
case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
228+
val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
229+
Intersect(newLeft, newRight)
217230
}
218231
}
219232

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
141141

142142
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
143143
override def output: Seq[Attribute] = left.output
144+
145+
override lazy val resolved: Boolean =
146+
childrenResolved &&
147+
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
144148
}
145149

146150
case class InsertIntoTable(
@@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode {
437441

438442
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
439443
override def output: Seq[Attribute] = left.output
444+
445+
override lazy val resolved: Boolean =
446+
childrenResolved &&
447+
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
440448
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
2020
import org.apache.spark.sql.catalyst.plans.PlanTest
2121

2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
23+
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.types._
2626

@@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest {
305305
)
306306
}
307307

308+
test("WidenTypes for union except and intersect") {
309+
def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
310+
logical.output.zip(expectTypes).foreach { case (attr, dt) =>
311+
assert(attr.dataType === dt)
312+
}
313+
}
314+
315+
val left = LocalRelation(
316+
AttributeReference("i", IntegerType)(),
317+
AttributeReference("u", DecimalType.Unlimited)(),
318+
AttributeReference("b", ByteType)(),
319+
AttributeReference("d", DoubleType)())
320+
val right = LocalRelation(
321+
AttributeReference("s", StringType)(),
322+
AttributeReference("d", DecimalType(2, 1))(),
323+
AttributeReference("f", FloatType)(),
324+
AttributeReference("l", LongType)())
325+
326+
val wt = HiveTypeCoercion.WidenTypes
327+
val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)
328+
329+
val r1 = wt(Union(left, right)).asInstanceOf[Union]
330+
val r2 = wt(Except(left, right)).asInstanceOf[Except]
331+
val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]
332+
checkOutput(r1.left, expectedTypes)
333+
checkOutput(r1.right, expectedTypes)
334+
checkOutput(r2.left, expectedTypes)
335+
checkOutput(r2.right, expectedTypes)
336+
checkOutput(r3.left, expectedTypes)
337+
checkOutput(r3.right, expectedTypes)
338+
}
339+
308340
/**
309341
* There are rules that need to not fire before child expressions get resolved.
310342
* We use this test to make sure those rules do not fire early.

0 commit comments

Comments
 (0)