
Commit 42f496f

maropu authored and HyukjinKwon committed
[SPARK-31526][SQL][TESTS] Add a new test suite for ExpressionInfo
### What changes were proposed in this pull request?

This PR intends to add a new test suite for `ExpressionInfo`. Major changes are as follows:

- Added a new test suite named `ExpressionInfoSuite`
- To improve test coverage, added a test for error handling in `ExpressionInfoSuite`
- Moved the `ExpressionInfo`-related tests from `UDFSuite` to `ExpressionInfoSuite`
- Moved the related tests from `SQLQuerySuite` to `ExpressionInfoSuite`
- Added a comment in `ExpressionInfoSuite` (followup of #28224)

### Why are the changes needed?

To improve test suites/coverage.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added tests.

Closes #28308 from maropu/SPARK-31526.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent f093480 · commit 42f496f

File tree

4 files changed: +162 -111 lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java

Lines changed: 6 additions & 0 deletions
@@ -103,6 +103,12 @@
   String arguments() default "";
   String examples() default "";
   String note() default "";
+  /**
+   * Valid group names are almost the same as the ones defined as `groupname` in
+   * `sql/functions.scala`. But `collection_funcs` is split into three fine-grained
+   * groups: `array_funcs`, `map_funcs`, and `json_funcs`. See `ExpressionInfo` for
+   * the detailed group names.
+   */
   String group() default "";
   String since() default "";
   String deprecated() default "";
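
For a quick look at the contract this comment describes, the sketch below constructs `ExpressionInfo` directly with a valid and an invalid group name. The constructor call and the error message mirror the tests moved in this commit; the specific names (`map_funcs`, `not_a_group`) are arbitrary picks:

```scala
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

// A group name from the documented set is accepted as-is.
val info = new ExpressionInfo(
  "testClass", null, "testName", null, "", "", "", "map_funcs", "", "")
assert(info.getGroup == "map_funcs")

// Any other group name is rejected at construction time.
try {
  new ExpressionInfo("testClass", null, "testName", null, "", "", "", "not_a_group", "", "")
} catch {
  case e: IllegalArgumentException =>
    // Prints: 'group' is malformed in the expression [testName]. ...
    println(e.getMessage)
}
```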

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 0 additions & 80 deletions
@@ -22,16 +22,13 @@ import java.net.{MalformedURLException, URL}
 import java.sql.{Date, Timestamp}
 import java.util.concurrent.atomic.AtomicBoolean
 
-import scala.collection.parallel.immutable.ParVector
-
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec

@@ -126,83 +123,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     }
   }
 
-  test("using _FUNC_ instead of function names in examples") {
-    val exampleRe = "(>.*;)".r
-    val setStmtRe = "(?i)^(>\\s+set\\s+).+".r
-    val ignoreSet = Set(
-      // Examples for CaseWhen show simpler syntax:
-      // `CASE WHEN ... THEN ... WHEN ... THEN ... END`
-      "org.apache.spark.sql.catalyst.expressions.CaseWhen",
-      // _FUNC_ is replaced by `locate` but `locate(... IN ...)` is not supported
-      "org.apache.spark.sql.catalyst.expressions.StringLocate",
-      // _FUNC_ is replaced by `%` which causes a parsing error on `SELECT %(2, 1.8)`
-      "org.apache.spark.sql.catalyst.expressions.Remainder",
-      // Examples demonstrate alternative names, see SPARK-20749
-      "org.apache.spark.sql.catalyst.expressions.Length")
-    spark.sessionState.functionRegistry.listFunction().foreach { funcId =>
-      val info = spark.sessionState.catalog.lookupFunctionInfo(funcId)
-      val className = info.getClassName
-      withClue(s"Expression class '$className'") {
-        val exprExamples = info.getOriginalExamples
-        if (!exprExamples.isEmpty && !ignoreSet.contains(className)) {
-          assert(exampleRe.findAllIn(exprExamples).toIterable
-            .filter(setStmtRe.findFirstIn(_).isEmpty) // Ignore SET commands
-            .forall(_.contains("_FUNC_")))
-        }
-      }
-    }
-  }
-
-  test("check outputs of expression examples") {
-    def unindentAndTrim(s: String): String = {
-      s.replaceAll("\n\\s+", "\n").trim
-    }
-    val beginSqlStmtRe = "  > ".r
-    val endSqlStmtRe = ";\n".r
-    def checkExampleSyntax(example: String): Unit = {
-      val beginStmtNum = beginSqlStmtRe.findAllIn(example).length
-      val endStmtNum = endSqlStmtRe.findAllIn(example).length
-      assert(beginStmtNum === endStmtNum,
-        "The number of `  > ` does not match to the number of `;`")
-    }
-    val exampleRe = """^(.+);\n(?s)(.+)$""".r
-    val ignoreSet = Set(
-      // One of examples shows getting the current timestamp
-      "org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
-      // Random output without a seed
-      "org.apache.spark.sql.catalyst.expressions.Rand",
-      "org.apache.spark.sql.catalyst.expressions.Randn",
-      "org.apache.spark.sql.catalyst.expressions.Shuffle",
-      "org.apache.spark.sql.catalyst.expressions.Uuid",
-      // The example calls methods that return unstable results.
-      "org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection")
-
-    val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
-    parFuncs.foreach { funcId =>
-      // Examples can change settings. We clone the session to prevent tests clashing.
-      val clonedSpark = spark.cloneSession()
-      // Coalescing partitions can change result order, so disable it.
-      clonedSpark.sessionState.conf.setConf(SQLConf.COALESCE_PARTITIONS_ENABLED, false)
-      val info = clonedSpark.sessionState.catalog.lookupFunctionInfo(funcId)
-      val className = info.getClassName
-      if (!ignoreSet.contains(className)) {
-        withClue(s"Function '${info.getName}', Expression class '$className'") {
-          val example = info.getExamples
-          checkExampleSyntax(example)
-          example.split("  > ").toList.foreach(_ match {
-            case exampleRe(sql, output) =>
-              val df = clonedSpark.sql(sql)
-              val actual = unindentAndTrim(
-                hiveResultString(df.queryExecution.executedPlan).mkString("\n"))
-              val expected = unindentAndTrim(output)
-              assert(actual === expected)
-            case _ =>
-          })
-        }
-      }
-    }
-  }
-
   test("SPARK-6743: no columns from cache") {
     Seq(
       (83, 0, 38),
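
A side note on the moved "check outputs of expression examples" test: examples may run `SET` statements, so each function is evaluated in a cloned session whose conf is independent of the parent's. A minimal sketch of that isolation, assuming a live `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.internal.SQLConf

val before = spark.sessionState.conf.getConf(SQLConf.COALESCE_PARTITIONS_ENABLED)

// Cloned sessions start from the parent's state but keep an independent conf,
// so a setting changed for one example cannot leak into other parallel runs.
val cloned = spark.cloneSession()
cloned.sessionState.conf.setConf(SQLConf.COALESCE_PARTITIONS_ENABLED, false)

// The parent session's setting is untouched by the clone's change.
assert(spark.sessionState.conf.getConf(SQLConf.COALESCE_PARTITIONS_ENABLED) == before)
```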

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 0 additions & 31 deletions
@@ -20,8 +20,6 @@ package org.apache.spark.sql
 import java.math.BigDecimal
 
 import org.apache.spark.sql.api.java._
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation

@@ -534,35 +532,6 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
   }
 
-  test("Replace _FUNC_ in UDF ExpressionInfo") {
-    val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("upper"))
-    assert(info.getName === "upper")
-    assert(info.getClassName === "org.apache.spark.sql.catalyst.expressions.Upper")
-    assert(info.getUsage === "upper(str) - Returns `str` with all characters changed to uppercase.")
-    assert(info.getExamples.contains("> SELECT upper('SparkSql');"))
-    assert(info.getSince === "1.0.1")
-    assert(info.getNote === "")
-    assert(info.getExtended.contains("> SELECT upper('SparkSql');"))
-  }
-
-  test("group info in ExpressionInfo") {
-    val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("sum"))
-    assert(info.getGroup === "agg_funcs")
-
-    Seq("agg_funcs", "array_funcs", "datetime_funcs", "json_funcs", "map_funcs", "window_funcs")
-      .foreach { groupName =>
-        val info = new ExpressionInfo(
-          "testClass", null, "testName", null, "", "", "", groupName, "", "")
-        assert(info.getGroup === groupName)
-      }
-
-    val errMsg = intercept[IllegalArgumentException] {
-      val invalidGroupName = "invalid_group_funcs"
-      new ExpressionInfo("testClass", null, "testName", null, "", "", "", invalidGroupName, "", "")
-    }.getMessage
-    assert(errMsg.contains("'group' is malformed in the expression [testName]."))
-  }
-
   test("SPARK-28521 error message for CAST(parameter types contains DataType)") {
     val e = intercept[AnalysisException] {
       spark.sql("SELECT CAST(1)")
sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import scala.collection.parallel.immutable.ParVector
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.execution.HiveResult.hiveResultString
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
+
+  test("Replace _FUNC_ in ExpressionInfo") {
+    val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("upper"))
+    assert(info.getName === "upper")
+    assert(info.getClassName === "org.apache.spark.sql.catalyst.expressions.Upper")
+    assert(info.getUsage === "upper(str) - Returns `str` with all characters changed to uppercase.")
+    assert(info.getExamples.contains("> SELECT upper('SparkSql');"))
+    assert(info.getSince === "1.0.1")
+    assert(info.getNote === "")
+    assert(info.getExtended.contains("> SELECT upper('SparkSql');"))
+  }
+
+  test("group info in ExpressionInfo") {
+    val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("sum"))
+    assert(info.getGroup === "agg_funcs")
+
+    Seq("agg_funcs", "array_funcs", "datetime_funcs", "json_funcs", "map_funcs", "window_funcs")
+      .foreach { groupName =>
+        val info = new ExpressionInfo(
+          "testClass", null, "testName", null, "", "", "", groupName, "", "")
+        assert(info.getGroup === groupName)
+      }
+
+    val errMsg = intercept[IllegalArgumentException] {
+      val invalidGroupName = "invalid_group_funcs"
+      new ExpressionInfo("testClass", null, "testName", null, "", "", "", invalidGroupName, "", "")
+    }.getMessage
+    assert(errMsg.contains("'group' is malformed in the expression [testName]."))
+  }
+
+  test("error handling in ExpressionInfo") {
+    val errMsg1 = intercept[IllegalArgumentException] {
+      val invalidNote = "  invalid note"
+      new ExpressionInfo("testClass", null, "testName", null, "", "", invalidNote, "", "", "")
+    }.getMessage
+    assert(errMsg1.contains("'note' is malformed in the expression [testName]."))
+
+    val errMsg2 = intercept[IllegalArgumentException] {
+      val invalidSince = "-3.0.0"
+      new ExpressionInfo("testClass", null, "testName", null, "", "", "", "", invalidSince, "")
+    }.getMessage
+    assert(errMsg2.contains("'since' is malformed in the expression [testName]."))
+
+    val errMsg3 = intercept[IllegalArgumentException] {
+      val invalidDeprecated = "  invalid deprecated"
+      new ExpressionInfo("testClass", null, "testName", null, "", "", "", "", "", invalidDeprecated)
+    }.getMessage
+    assert(errMsg3.contains("'deprecated' is malformed in the expression [testName]."))
+  }
+
+  test("using _FUNC_ instead of function names in examples") {
+    val exampleRe = "(>.*;)".r
+    val setStmtRe = "(?i)^(>\\s+set\\s+).+".r
+    val ignoreSet = Set(
+      // Examples for CaseWhen show simpler syntax:
+      // `CASE WHEN ... THEN ... WHEN ... THEN ... END`
+      "org.apache.spark.sql.catalyst.expressions.CaseWhen",
+      // _FUNC_ is replaced by `locate` but `locate(... IN ...)` is not supported
+      "org.apache.spark.sql.catalyst.expressions.StringLocate",
+      // _FUNC_ is replaced by `%` which causes a parsing error on `SELECT %(2, 1.8)`
+      "org.apache.spark.sql.catalyst.expressions.Remainder",
+      // Examples demonstrate alternative names, see SPARK-20749
+      "org.apache.spark.sql.catalyst.expressions.Length")
+    spark.sessionState.functionRegistry.listFunction().foreach { funcId =>
+      val info = spark.sessionState.catalog.lookupFunctionInfo(funcId)
+      val className = info.getClassName
+      withClue(s"Expression class '$className'") {
+        val exprExamples = info.getOriginalExamples
+        if (!exprExamples.isEmpty && !ignoreSet.contains(className)) {
+          assert(exampleRe.findAllIn(exprExamples).toIterable
+            .filter(setStmtRe.findFirstIn(_).isEmpty) // Ignore SET commands
+            .forall(_.contains("_FUNC_")))
+        }
+      }
+    }
+  }
+
+  test("check outputs of expression examples") {
+    def unindentAndTrim(s: String): String = {
+      s.replaceAll("\n\\s+", "\n").trim
+    }
+    val beginSqlStmtRe = "  > ".r
+    val endSqlStmtRe = ";\n".r
+    def checkExampleSyntax(example: String): Unit = {
+      val beginStmtNum = beginSqlStmtRe.findAllIn(example).length
+      val endStmtNum = endSqlStmtRe.findAllIn(example).length
+      assert(beginStmtNum === endStmtNum,
+        "The number of `  > ` does not match to the number of `;`")
+    }
+    val exampleRe = """^(.+);\n(?s)(.+)$""".r
+    val ignoreSet = Set(
+      // One of examples shows getting the current timestamp
+      "org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
+      // Random output without a seed
+      "org.apache.spark.sql.catalyst.expressions.Rand",
+      "org.apache.spark.sql.catalyst.expressions.Randn",
+      "org.apache.spark.sql.catalyst.expressions.Shuffle",
+      "org.apache.spark.sql.catalyst.expressions.Uuid",
+      // The example calls methods that return unstable results.
+      "org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection")
+
+    val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
+    parFuncs.foreach { funcId =>
+      // Examples can change settings. We clone the session to prevent tests clashing.
+      val clonedSpark = spark.cloneSession()
+      // Coalescing partitions can change result order, so disable it.
+      clonedSpark.sessionState.conf.setConf(SQLConf.COALESCE_PARTITIONS_ENABLED, false)
+      val info = clonedSpark.sessionState.catalog.lookupFunctionInfo(funcId)
+      val className = info.getClassName
+      if (!ignoreSet.contains(className)) {
+        withClue(s"Function '${info.getName}', Expression class '$className'") {
+          val example = info.getExamples
+          checkExampleSyntax(example)
+          example.split("  > ").toList.foreach {
+            case exampleRe(sql, output) =>
+              val df = clonedSpark.sql(sql)
+              val actual = unindentAndTrim(
+                hiveResultString(df.queryExecution.executedPlan).mkString("\n"))
+              val expected = unindentAndTrim(output)
+              assert(actual === expected)
+            case _ =>
+          }
+        }
+      }
+    }
+  }
+}
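
The `_FUNC_` test above relies on the difference between the raw annotation text and its rendered form: `getOriginalExamples` returns the examples as written, with the placeholder intact, while `getExamples` substitutes the registered function name. A minimal sketch, assuming a live `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier

val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("upper"))

// The annotation text keeps the placeholder...
assert(info.getOriginalExamples.contains("_FUNC_"))
// ...while the rendered examples have it replaced with the registered name.
assert(info.getExamples.contains("> SELECT upper('SparkSql');"))
```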
