Skip to content

Commit db71dbd

Browse files
committed
handle special characters in elements in crosstab
1 parent 34d448d commit db71dbd

File tree

3 files changed

+44
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
391391
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
392392
*/
393393
private def fillCol[T](col: StructField, replacement: T): Column = {
394-
coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
394+
coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
395395
}
396396

397397
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ private[sql] object StatFunctions extends Logging {
110110
logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
111111
"the pairs. Please try reducing the amount of distinct items in your columns.")
112112
}
113+
def cleanElement(element: Any): String = {
114+
if (element == null) "" else element.toString
115+
}
113116
// get the distinct values of column 2, so that we can make them the column names
114-
val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap
117+
val distinctCol2: Map[Any, Int] =
118+
counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap
115119
val columnSize = distinctCol2.size
116120
require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
117121
s"exceed 1e4. Currently $columnSize")
@@ -121,15 +125,23 @@ private[sql] object StatFunctions extends Logging {
121125
// row.get(0) is column 1
122126
// row.get(1) is column 2
123127
// row.get(2) is the frequency
124-
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
128+
val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get
129+
countsRow.setLong(columnIndex + 1, row.getLong(2))
125130
}
126131
// the value of col1 is the first value, the rest are the counts
127-
countsRow.update(0, UTF8String.fromString(col1Item.toString))
132+
countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
128133
countsRow
129134
}.toSeq
135+
// Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
136+
// special keywords and `.`, wrap the column names in ``.
137+
def cleanColumnName(name: String): String = {
138+
name.replace("`", "")
139+
}
130140
// In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
131141
// SPARK-8681. We need to explicitly sort by the column index and assign the column names.
132-
val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
142+
val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
143+
StructField(cleanColumnName(r._1.toString), LongType)
144+
}
133145
val schema = StructType(StructField(tableName, StringType) +: headerNames)
134146

135147
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,33 @@ class DataFrameStatSuite extends SparkFunSuite {
8585
}
8686
}
8787

88+
test("special crosstab elements (., '', null, ``)") {
89+
val data = Seq(
90+
("a", 1, "ho"),
91+
(null, 2, "ho"),
92+
("a.b", 1, ""),
93+
("b", 2, "`ha`"),
94+
("a", 1, null)
95+
)
96+
val df = data.toDF("1", "2", "3")
97+
val ct1 = df.stat.crosstab("1", "2")
98+
// column fields should be 1 + distinct elements of second column
99+
assert(ct1.schema.fields.length === 3)
100+
assert(ct1.collect().length === 4)
101+
val ct2 = df.stat.crosstab("1", "3")
102+
assert(ct2.schema.fields.length === 4)
103+
assert(ct2.schema.fieldNames.contains("ha"))
104+
assert(ct2.collect().length === 4)
105+
val ct3 = df.stat.crosstab("3", "2")
106+
assert(ct3.schema.fields.length === 3)
107+
assert(ct3.collect().length === 4)
108+
val ct4 = df.stat.crosstab("3", "1")
109+
assert(ct4.schema.fields.length === 5)
110+
assert(ct4.schema.fieldNames.contains(""))
111+
assert(ct4.schema.fieldNames.contains("a.b"))
112+
assert(ct4.collect().length === 4)
113+
}
114+
88115
test("Frequent Items") {
89116
val rows = Seq.tabulate(1000) { i =>
90117
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)

0 commit comments

Comments (0)