Commit 6a05386

Fixed use unnamed version
1 parent 80a3e13 commit 6a05386

File tree

3 files changed: +23 -10 lines

  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
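A note on what this commit addresses, inferred from the diffs and the new test below: when a hash-aggregate grouping key is an expression such as regexp_extract, its auto-generated field name carries the raw pattern text, backslashes included, and both fast hash-map generators used to interpolate that name straight into the generated Java source. The sketch below is an illustrative, self-contained repro of that query shape, not part of the commit; the local SparkSession harness is an assumption, since the real test runs under SharedSQLContext and uses checkAnswer.

// Illustrative repro sketch for SPARK-18952 (hypothetical harness; mirrors the new test's query shape).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.regexp_extract

object Spark18952ReproSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-18952 repro").getOrCreate()
    import spark.implicits._

    val df = Seq(("some[thing]", "random-string")).toDF("key", "val")

    // The grouping key's generated column name contains a backslash (from the regex pattern);
    // per the test title, this used to fail code generation for the aggregate's fast hash map.
    df.groupBy(regexp_extract($"key", "([a-z]+)\\[", 1)).count().show()

    spark.stop()
  }
}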

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala

Lines changed: 7 additions & 5 deletions
@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedKeySchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         groupingKeySchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedValueSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
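The change above, repeated for VectorizedHashMapGenerator below, stops splicing key.name verbatim into the generated Java source: the name is registered with the codegen context via ctx.addReferenceObj and only the returned token is emitted. The toy sketch that follows is not Spark's actual CodegenContext; under that simplified assumption it only illustrates why a raw name containing a backslash breaks a generated Java string literal, while a value kept in a references array never needs to be escaped into the source text.

// Toy illustration only -- a stand-in for the real CodegenContext, showing the escaping hazard.
object CodegenEscapeSketch {
  // A field name derived from a regex-bearing expression; note the embedded backslash.
  val keyName = """regexp_extract(key, ([a-z]+)\[, 1)"""

  // Naive approach: interpolate the raw name into a Java string literal in the generated source.
  // The emitted text then contains \[ inside quotes, which is not a valid Java escape sequence,
  // so the generated class would fail to compile.
  val naiveSource: String =
    s""".add("$keyName", org.apache.spark.sql.types.DataTypes.StringType)"""

  // Reference-style approach (the general idea behind addReferenceObj): keep the problematic
  // value in an object array owned by the generated class and emit only an indirect lookup.
  val references = scala.collection.mutable.ArrayBuffer[AnyRef]()
  def addReference(obj: AnyRef): Int = { references += obj; references.length - 1 }

  val idx = addReference(keyName)
  val safeSource: String =
    s""".add((String) references[$idx], org.apache.spark.sql.types.DataTypes.StringType)"""

  def main(args: Array[String]): Unit = {
    println(naiveSource)
    println(safeSource)
  }
}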

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala

Lines changed: 7 additions & 5 deletions
@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
   extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
     groupingKeySchema, bufferSchema) {
 
-  protected def initializeAggregateHashMap(): String = {
+  override protected def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         (groupingKeySchema ++ bufferSchema).map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
           }
         }.mkString("\n").concat(";")
 
     val generatedAggBufferSchema: String =
       s"new org.apache.spark.sql.types.StructType()" +
         bufferSchema.map { key =>
+          val keyName = ctx.addReferenceObj(key.name)
           key.dataType match {
             case d: DecimalType =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
                 |${d.precision}, ${d.scale}))""".stripMargin
             case _ =>
-              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+              s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
          }
        }.mkString("\n").concat(";")
 

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+    val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
+
+    checkAnswer(
+      df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
+      Row("some", 1) :: Nil
+    )
+  }
+
   test("rollup") {
     checkAnswer(
       courseSales.rollup("course", "year").sum("earnings"),
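A small aside on the escape layers exercised by the test: in Scala source, "([a-z]+)\\[" is the ten-character string ([a-z]+)\[, and for the regex engine \[ matches a literal bracket; it is that surviving backslash in the expression's generated name that previously leaked into the generated Java. The standalone snippet below (no Spark required; the object name is mine) just confirms the pattern extracts "some" from "some[thing]", matching the Row("some", 1) the test expects.

// Standalone check of the regex used in the test; not part of the commit.
import java.util.regex.Pattern

object EscapeLayersCheck {
  def main(args: Array[String]): Unit = {
    val pattern = "([a-z]+)\\["           // source-level \\ yields a single backslash in the string
    println(pattern)                      // prints ([a-z]+)\[

    val m = Pattern.compile(pattern).matcher("some[thing]")
    if (m.find()) println(m.group(1))     // prints some
  }
}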
