From db71dbd1feb48b13aa84945df12374f5eb784779 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 2 Jul 2015 16:06:45 -0700
Subject: [PATCH 1/4] handle special characters in elements in crosstab

---
 .../spark/sql/DataFrameNaFunctions.scala        |  2 +-
 .../sql/execution/stat/StatFunctions.scala      | 20 +++++++++++---
 .../apache/spark/sql/DataFrameStatSuite.scala   | 27 +++++++++++++++++++
 3 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index b4c2daa05586..8681a56c82f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
-    coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 23ddfa9839e5..67489e6d34fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")
     }
+    def cleanElement(element: Any): String = {
+      if (element == null) "" else element.toString
+    }
     // get the distinct values of column 2, so that we can make them the column names
-    val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val distinctCol2: Map[Any, Int] =
+      counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
@@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging {
         // row.get(0) is column 1
         // row.get(1) is column 2
         // row.get(2) is the frequency
-        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+        val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
+        countsRow.setLong(columnIndex + 1, row.getLong(2))
       }
       // the value of col1 is the first value, the rest are the counts
-      countsRow.update(0, UTF8String.fromString(col1Item.toString))
+      countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
       countsRow
     }.toSeq
+    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
+    // special keywords and `.`, wrap the column names in ``.
+    def cleanColumnName(name: String): String = {
+      name.replace("`", "")
+    }
     // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
     // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
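[Between the patches, a minimal sketch of the failure this series guards against — a hypothetical snippet, not part of the patch; it assumes a Spark 1.4-style `sqlContext` in scope, and the column names are illustrative:

    // Cell values of the second column become result column names. A name
    // like "a.b" used to break the na.fill(0.0) step (fillCol resolved it as
    // field b nested in column a), and literal backticks cannot appear in a
    // column name at all. The patch backtick-quotes the name inside fillCol
    // and strips literal backticks when forming header names.
    val data = Seq(("r1", "a.b"), ("r1", "`c`"), ("r2", "a.b"))
    val df = sqlContext.createDataFrame(data).toDF("key", "value")
    df.stat.crosstab("key", "value").schema.fieldNames
    // -> Array("key_value", "a.b", "c")   (column order illustrative)
]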
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
+      StructField(cleanColumnName(r._1.toString), LongType)
+    }
     val schema = StructType(StructField(tableName, StringType) +: headerNames)
 
     new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 765094da6bda..549a85ee3c89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -85,6 +85,33 @@ class DataFrameStatSuite extends SparkFunSuite {
     }
   }
 
+  test("special crosstab elements (., '', null, ``)") {
+    val data = Seq(
+      ("a", 1, "ho"),
+      (null, 2, "ho"),
+      ("a.b", 1, ""),
+      ("b", 2, "`ha`"),
+      ("a", 1, null)
+    )
+    val df = data.toDF("1", "2", "3")
+    val ct1 = df.stat.crosstab("1", "2")
+    // column fields should be 1 + distinct elements of second column
+    assert(ct1.schema.fields.length === 3)
+    assert(ct1.collect().length === 4)
+    val ct2 = df.stat.crosstab("1", "3")
+    assert(ct2.schema.fields.length === 4)
+    assert(ct2.schema.fieldNames.contains("ha"))
+    assert(ct2.collect().length === 4)
+    val ct3 = df.stat.crosstab("3", "2")
+    assert(ct3.schema.fields.length === 3)
+    assert(ct3.collect().length === 4)
+    val ct4 = df.stat.crosstab("3", "1")
+    assert(ct4.schema.fields.length === 5)
+    assert(ct4.schema.fieldNames.contains(""))
+    assert(ct4.schema.fieldNames.contains("a.b"))
+    assert(ct4.collect().length === 4)
+  }
+
   test("Frequent Items") {
     val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)

From 9dba6ce5844fc3033d0e7c8e17babcb31270ae6f Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 2 Jul 2015 18:02:37 -0700
Subject: [PATCH 2/4] address cr1

---
 .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala  | 3 +++
 .../org/apache/spark/sql/execution/stat/StatFunctions.scala  | 2 +-
 .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 2 +-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba56..9e8c27a1a0b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * The first column of each row will be the distinct values of `col1` and the column names will
    * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
    * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
+   * exist. 
+   * 
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 67489e6d34fd..00231d65a7d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging {
         "the pairs. Please try reducing the amount of distinct items in your columns.")
     }
     def cleanElement(element: Any): String = {
-      if (element == null) "" else element.toString
+      if (element == null) "null" else element.toString
     }
     // get the distinct values of column 2, so that we can make them the column names
     val distinctCol2: Map[Any, Int] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 549a85ee3c89..bfe14b4b70e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -99,7 +99,7 @@ class DataFrameStatSuite extends SparkFunSuite {
     assert(ct1.schema.fields.length === 3)
     assert(ct1.collect().length === 4)
     val ct2 = df.stat.crosstab("1", "3")
-    assert(ct2.schema.fields.length === 4)
+    assert(ct2.schema.fields.length === 5)
     assert(ct2.schema.fieldNames.contains("ha"))
     assert(ct2.collect().length === 4)
     val ct3 = df.stat.crosstab("3", "2")
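[To make the review change concrete: after patch 2, `cleanElement` renders a null cell as the string "null" rather than an empty name, so the resulting column can still be referenced. A small hypothetical check, under the same `sqlContext` assumption as above:

    val df = sqlContext.createDataFrame(Seq(("a", "x"), (null, "x"))).toDF("c1", "c2")
    // The null value in c1 becomes a crosstab column literally named "null".
    assert(df.stat.crosstab("c2", "c1").schema.fieldNames.contains("null"))
]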
From 93a0d3fa2725749d6013aeae987247cf45937593 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 2 Jul 2015 19:46:04 -0700
Subject: [PATCH 3/4] added tests for NaN and Infinity

---
 .../apache/spark/sql/DataFrameStatSuite.scala | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index bfe14b4b70e7..7ba4ba73e0cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -87,27 +87,30 @@ class DataFrameStatSuite extends SparkFunSuite {
 
   test("special crosstab elements (., '', null, ``)") {
     val data = Seq(
-      ("a", 1, "ho"),
-      (null, 2, "ho"),
-      ("a.b", 1, ""),
-      ("b", 2, "`ha`"),
-      ("a", 1, null)
+      ("a", Double.NaN, "ho"),
+      (null, 2.0, "ho"),
+      ("a.b", Double.NegativeInfinity, ""),
+      ("b", Double.PositiveInfinity, "`ha`"),
+      ("a", 1.0, null)
     )
     val df = data.toDF("1", "2", "3")
     val ct1 = df.stat.crosstab("1", "2")
     // column fields should be 1 + distinct elements of second column
-    assert(ct1.schema.fields.length === 3)
+    assert(ct1.schema.fields.length === 6)
     assert(ct1.collect().length === 4)
     val ct2 = df.stat.crosstab("1", "3")
     assert(ct2.schema.fields.length === 5)
     assert(ct2.schema.fieldNames.contains("ha"))
     assert(ct2.collect().length === 4)
     val ct3 = df.stat.crosstab("3", "2")
-    assert(ct3.schema.fields.length === 3)
+    assert(ct3.schema.fields.length === 6)
+    assert(ct3.schema.fieldNames.contains("NaN"))
+    assert(ct3.schema.fieldNames.contains("Infinity"))
+    assert(ct3.schema.fieldNames.contains("-Infinity"))
     assert(ct3.collect().length === 4)
     val ct4 = df.stat.crosstab("3", "1")
     assert(ct4.schema.fields.length === 5)
-    assert(ct4.schema.fieldNames.contains(""))
+    assert(ct4.schema.fieldNames.contains("null"))
     assert(ct4.schema.fieldNames.contains("a.b"))
     assert(ct4.collect().length === 4)
   }
 
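[The NaN and Infinity column names asserted in patch 3 are not special-cased anywhere; they fall out of `cleanElement` calling `toString`, and Scala's standard `Double` rendering produces exactly these strings (standard-library behavior, shown for reference):

    Double.NaN.toString               // "NaN"
    Double.PositiveInfinity.toString  // "Infinity"
    Double.NegativeInfinity.toString  // "-Infinity"
]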
From e06b8407d847a0811f224287e78c0ebf5c199497 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 2 Jul 2015 20:30:22 -0700
Subject: [PATCH 4/4] fix scalastyle

---
 .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9e8c27a1a0b6..587869e57f96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -79,8 +79,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts
    * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
    * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
-   * exist. 
-   * 
+   * exist.
+   *
    *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.