From aa3b4e872363503310150bfcf7dd099287719154 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Jan 2016 16:56:39 -0800 Subject: [PATCH 1/3] fix cast in filter --- .../scala/org/apache/spark/sql/Column.scala | 17 ++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 97bf7a0cc451..2ab091e40a07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -133,6 +133,15 @@ class Column(protected[sql] val expr: Expression) extends Logging { case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + // If we have a top level Cast, there is a chance to give it a better alias, if there is a + // NamedExpression under this Cast. + case c: Cast => c.transformUp { + case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) + } match { + case ne: NamedExpression => ne + case other => Alias(expr, expr.prettyString)() + } + case expr: Expression => Alias(expr, expr.prettyString)() } @@ -921,13 +930,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { - expr match { - // keeps the name of expression if possible when do cast. - case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) - case _ => Cast(expr, to) - } - } + def cast(to: DataType): Column = withExpr { Cast(expr, to) } /** * Casts the column to a different data type, using the canonical string representation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d6c140dfea9e..6cd887a23035 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1228,4 +1228,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b")) checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c")) } + + test("SPARK-12841: cast in filter") { + checkAnswer( + Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"), + Row(1, "a")) + } } From af96c5d3d3a687b6051dacd2b7a9da3ec1767a1f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jan 2016 10:25:54 -0800 Subject: [PATCH 2/3] more test --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6cd887a23035..afc8df07fd9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1007,6 +1007,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") + assert(df.select($"src.i".cast(StringType).cast(IntegerType)).columns.head === "i") } test("SPARK-11301: fix case sensitivity for filter on partitioned columns") { From d8681b4fcc8a260d6e5da7240b94b952c74d1d60 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jan 2016 11:43:52 -0800 Subject: [PATCH 3/3] small fix --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dadea6b54a94..9257fba60e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -147,7 +147,7 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr transform { + expr transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => child match { case ne: NamedExpression => ne case e if !e.resolved => u