Skip to content

Commit 1563f12

Browse files
committed
address some comments.
1 parent b539e94 commit 1563f12

File tree

10 files changed

+447
-54
lines changed

10 files changed

+447
-54
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public class ExpressionInfo {
2525
public enum FunctionType {
2626
BUILTIN, PERSISTENT, TEMPORARY;
2727
}
28+
2829
private String className;
2930
private String usage;
3031
private String name;
@@ -65,10 +66,6 @@ public ExpressionInfo(String className, String db, String name, String usage, St
6566
this.functionType = functionType;
6667
}
6768

68-
public ExpressionInfo(String className, String db, String name, String usage, String extended) {
69-
this(className, db, name, usage, extended, FunctionType.TEMPORARY);
70-
}
71-
7269
public ExpressionInfo(String className, String name) {
7370
this(className, null, name, null, null, FunctionType.TEMPORARY);
7471
}

sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,12 @@ class AnalysisException protected[sql] (
5555
s"$message;$lineAnnotation$positionAnnotation"
5656
}
5757
}
58+
59+
object AnalysisException {
60+
/**
61+
* Creates the exception that is thrown when a temporary macro cannot be found.
62+
*/
63+
def noSuchTempMacroException(func: String): AnalysisException = {
64+
new AnalysisException(s"Temporary macro '$func' not found")
65+
}
66+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class SystemFunctionRegistry(builtin: SimpleFunctionRegistry) extends SimpleFunc
144144
}
145145

146146
override def listFunction(): Seq[String] = synchronized {
147-
(functionBuilders.iterator.map(_._1).toList ++ builtin.listFunction()).sorted
147+
(functionBuilders.iterator.map(_._1).toList ++ builtin.listFunction()).distinct.sorted
148148
}
149149

150150
override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized {
@@ -160,6 +160,7 @@ class SystemFunctionRegistry(builtin: SimpleFunctionRegistry) extends SimpleFunc
160160
}
161161

162162
override def clear(): Unit = synchronized {
163+
builtin.clear()
163164
functionBuilders.clear()
164165
}
165166

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,3 @@ class NoSuchPartitionsException(db: String, table: String, specs: Seq[TableParti
5252

5353
class NoSuchTempFunctionException(func: String)
5454
extends AnalysisException(s"Temporary function '$func' not found")
55-
56-
class NoSuchTempMacroException(func: String)
57-
extends AnalysisException(s"Temporary macro '$func' not found")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ class SessionCatalog(
11021102
/** Drop a temporary macro. */
11031103
def dropTempMacro(name: String, ignoreIfNotExists: Boolean): Unit = {
11041104
if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
1105-
throw new NoSuchTempMacroException(name)
1105+
throw AnalysisException.noSuchTempMacroException(name)
11061106
}
11071107
}
11081108

@@ -1265,8 +1265,8 @@ class SessionCatalog(
12651265
if (func.database.isDefined) {
12661266
dropFunction(func, ignoreIfNotExists = false)
12671267
} else {
1268-
val functionType = functionRegistry.lookupFunction(func.funcName).map(_.getFunctionType)
1269-
.getOrElse(FunctionType.TEMPORARY)
1268+
val functionType = functionRegistry.lookupFunction(func.funcName)
1269+
.map(_.getFunctionType).getOrElse(FunctionType.TEMPORARY)
12701270
if (!functionType.equals(FunctionType.BUILTIN)) {
12711271
dropTempFunction(func.funcName, ignoreIfNotExists = false)
12721272
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -724,12 +724,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
724724
* }}}
725725
*/
726726
override def visitCreateMacro(ctx: CreateMacroContext): LogicalPlan = withOrigin(ctx) {
727-
val arguments = Option(ctx.colTypeList).map(visitColTypeList(_))
728-
.getOrElse(Seq.empty[StructField])
727+
val columns = createSchema(ctx.colTypeList)
729728
val e = expression(ctx.expression)
730729
CreateMacroCommand(
731730
ctx.macroName.getText,
732-
MacroFunctionWrapper(arguments, e))
731+
MacroFunctionWrapper(columns, e))
733732
}
734733

735734
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/command/macros.scala

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ package org.apache.spark.sql.execution.command
2020
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
2121
import org.apache.spark.sql.catalyst.analysis._
2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.types.StructField
23+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
24+
import org.apache.spark.sql.types.StructType
2425

2526
/**
2627
* This class provides the arguments and the body expression of the macro function.
2728
*/
28-
case class MacroFunctionWrapper(columns: Seq[StructField], macroFunction: Expression)
29+
case class MacroFunctionWrapper(columns: StructType, macroFunction: Expression)
30+
2931

3032
/**
3133
* The DDL command that creates a macro.
@@ -41,16 +43,33 @@ case class CreateMacroCommand(
4143

4244
override def run(sparkSession: SparkSession): Seq[Row] = {
4345
val catalog = sparkSession.sessionState.catalog
44-
val columns = funcWrapper.columns.map { col =>
45-
AttributeReference(col.name, col.dataType, col.nullable, col.metadata)() }
46-
val colToIndex: Map[String, Int] = columns.map(_.name).zipWithIndex.toMap
46+
val columns = funcWrapper.columns
47+
val columnAttrs = columns.toAttributes
48+
def formatName: (String => String) =
49+
if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
50+
(name: String) => name
51+
} else {
52+
(name: String) => name.toLowerCase
53+
}
54+
val colToIndex: Map[String, Int] = columnAttrs.map(_.name).map(formatName).zipWithIndex.toMap
4755
if (colToIndex.size != columns.size) {
4856
throw new AnalysisException(s"Cannot support duplicate colNames " +
4957
s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}")
5058
}
59+
60+
try {
61+
val plan = Project(Seq(Alias(funcWrapper.macroFunction, "m")()), LocalRelation(columnAttrs))
62+
val analyzed = sparkSession.sessionState.analyzer.execute(plan)
63+
sparkSession.sessionState.analyzer.checkAnalysis(analyzed)
64+
} catch {
65+
case a: AnalysisException =>
66+
throw new AnalysisException(s"CREATE TEMPORARY MACRO $macroName " +
67+
s"with exception: ${a.getMessage}")
68+
}
69+
5170
val macroFunction = funcWrapper.macroFunction.transform {
5271
case u: UnresolvedAttribute =>
53-
val index = colToIndex.get(u.name).getOrElse(
72+
val index = colToIndex.get(formatName(u.name)).getOrElse(
5473
throw new AnalysisException(s"Cannot find colName: ${u} " +
5574
s"for CREATE TEMPORARY MACRO $macroName, actual columns: ${columns.mkString(",")}"))
5675
BoundReference(index, columns(index).dataType, columns(index).nullable)
@@ -64,15 +83,15 @@ case class CreateMacroCommand(
6483
s"for CREATE TEMPORARY MACRO $macroName")
6584
}
6685

67-
val macroInfo = columns.mkString(",") + " -> " + funcWrapper.macroFunction.toString
68-
val info = new ExpressionInfo(macroInfo, macroName)
86+
val columnLength: Int = columns.length
87+
val info = new ExpressionInfo(macroName, macroName)
6988
val builder = (children: Seq[Expression]) => {
70-
if (children.size != columns.size) {
89+
if (children.size != columnLength) {
7190
throw new AnalysisException(s"Actual number of columns: ${children.size} != " +
72-
s"expected number of columns: ${columns.size} for Macro $macroName")
91+
s"expected number of columns: ${columnLength} for Macro $macroName")
7392
}
7493
macroFunction.transform {
75-
// Skip to validate the input type because check it at runtime.
94+
// Skip validating the input types here because they have already been checked.
7695
case b: BoundReference => children(b.ordinal)
7796
}
7897
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x));
2+
SELECT SIGMOID(2);
3+
DROP TEMPORARY MACRO SIGMOID;
4+
5+
CREATE TEMPORARY MACRO FIXED_NUMBER() 1;
6+
SELECT FIXED_NUMBER() + 1;
7+
DROP TEMPORARY MACRO FIXED_NUMBER;
8+
9+
CREATE TEMPORARY MACRO SIMPLE_ADD (x INT, y INT) x + y;
10+
SELECT SIMPLE_ADD(1, 9);
11+
DROP TEMPORARY MACRO SIMPLE_ADD;
12+
13+
CREATE TEMPORARY MACRO flr(d bigint) FLOOR(d/10)*10;
14+
SELECT flr(12);
15+
DROP TEMPORARY MACRO flr;
16+
17+
CREATE TEMPORARY MACRO STRING_LEN(x string) length(x);
18+
CREATE TEMPORARY MACRO STRING_LEN_PLUS_ONE(x string) length(x)+1;
19+
CREATE TEMPORARY MACRO STRING_LEN_PLUS_TWO(x string) length(x)+2;
20+
create table macro_test (x string) using parquet;;
21+
insert into table macro_test values ("bb"), ("a"), ("ccc");
22+
SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a
23+
FROM macro_test;
24+
SELECT CONCAT(STRING_LEN(x), ":", STRING_LEN_PLUS_ONE(x), ":", STRING_LEN_PLUS_TWO(x)) a
25+
FROM
26+
macro_test
27+
sort by a;
28+
drop table macro_test;
29+
30+
CREATE TABLE macro_testing(a int, b int, c int) using parquet;;
31+
insert into table macro_testing values (1,2,3);
32+
insert into table macro_testing values (4,5,6);
33+
CREATE TEMPORARY MACRO math_square(x int) x*x;
34+
CREATE TEMPORARY MACRO math_add(x int) x+x;
35+
select math_square(a), math_square(b),factorial(a), factorial(b), math_add(a), math_add(b),int(c)
36+
from macro_testing order by int(c);
37+
drop table macro_testing;
38+
39+
CREATE TEMPORARY MACRO max(x int, y int) x + y;
40+
SELECT max(1, 2);
41+
DROP TEMPORARY MACRO max;
42+
SELECT max(2);
43+
44+
CREATE TEMPORARY MACRO c() 3E9;
45+
SELECT floor(c()/10);
46+
DROP TEMPORARY MACRO c;
47+
48+
CREATE TEMPORARY MACRO fixed_number() 42;
49+
DROP TEMPORARY FUNCTION fixed_number;
50+
DROP TEMPORARY MACRO IF EXISTS fixed_number;
51+
52+
-- invalid queries
53+
CREATE TEMPORARY MACRO simple_add_error(x int) x + y;
54+
CREATE TEMPORARY MACRO simple_add_error(x int, x int) x + y;
55+
CREATE TEMPORARY MACRO simple_add_error(x int) x NOT IN (select c2);
56+
DROP TEMPORARY MACRO SOME_MACRO;

0 commit comments

Comments
 (0)