33 commits
29fdaba  expressionWithAlias for First, Last, StddevSamp, VarianceSamp (amanomer, Dec 9, 2019)
7d5f4be  Fixed errors (amanomer, Dec 9, 2019)
7ba7802  +constructor First (amanomer, Dec 10, 2019)
759262d  +constructor Last (amanomer, Dec 10, 2019)
5ec101f  Fixed build error (amanomer, Dec 10, 2019)
f210bb9  Fixed TC (amanomer, Dec 10, 2019)
cc234b9  ScalaStyle Fix (amanomer, Dec 10, 2019)
7617ec3  Fixed TC (amanomer, Dec 11, 2019)
92e381b  Reduce constructors (amanomer, Dec 11, 2019)
a2d75de  nit (amanomer, Dec 11, 2019)
3ef62af  Removed unnecessary changes (amanomer, Dec 11, 2019)
87c6ea3  Fixed TC (amanomer, Dec 11, 2019)
9480532  Override flatArguments in VarianceSamp, StddevSamp (amanomer, Dec 11, 2019)
5c540fc  override flatArguments in First, Last (amanomer, Dec 12, 2019)
d780dfc  override flatArguments in BoolAnd, BoolOr (amanomer, Dec 12, 2019)
85d9597  add assert() (amanomer, Dec 12, 2019)
a71e8a7  expressionWithAlias for Average, ApproximatePercentile & override nod… (amanomer, Dec 13, 2019)
dd2d85d  Fixes for latest update (amanomer, Dec 13, 2019)
c1b3afb  Fix ApproximatePercentile TC (amanomer, Dec 14, 2019)
e7a4e90  UT fix (amanomer, Dec 14, 2019)
ca886f0  nit (amanomer, Dec 18, 2019)
125cfac  expressionWithTreeNodeTag for ApproximatePercentile (amanomer, Dec 18, 2019)
4ca20f4  expressionWithTreeNodeTag for BoolAnd, BoolOr, StddevSamp and Varianc… (amanomer, Dec 18, 2019)
aecdd8a  expressionWithTreeNodeTag for First, Last and Average (amanomer, Dec 18, 2019)
bbd4397  Renaming to expressionWithTNT (amanomer, Dec 18, 2019)
9146913  nit (amanomer, Dec 18, 2019)
8e9e42b  Avoid duplicate code (amanomer, Dec 18, 2019)
36418e2  small fix (amanomer, Dec 18, 2019)
ce8ea17  move FUNC_ALIAS to FunctionRegistry (amanomer, Dec 19, 2019)
4b536dd  Remove expressionWithAlias (amanomer, Dec 19, 2019)
737f33a  revert reorder (amanomer, Dec 19, 2019)
1920940  override prettyName instead of nodeName (amanomer, Dec 19, 2019)
700a84d  Merge branch 'master' into fncAlias (amanomer, Dec 19, 2019)
Files changed

File: FunctionRegistry.scala

@@ -22,7 +22,6 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.xml._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.types._
 
 
@@ -193,6 +193,8 @@ object FunctionRegistry {
 
   type FunctionBuilder = Seq[Expression] => Expression
 
+  val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName")
+
   // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite
   val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
     // misc non-aggregate functions
@@ -289,35 +291,35 @@
     expression[CovPopulation]("covar_pop"),
     expression[CovSample]("covar_samp"),
     expression[First]("first"),
-    expression[First]("first_value"),
+    expression[First]("first_value", true),
     expression[Kurtosis]("kurtosis"),
     expression[Last]("last"),
-    expression[Last]("last_value"),
+    expression[Last]("last_value", true),
     expression[Max]("max"),
     expression[MaxBy]("max_by"),
-    expression[Average]("mean"),
+    expression[Average]("mean", true),
     expression[Min]("min"),
     expression[MinBy]("min_by"),
     expression[Percentile]("percentile"),
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
-    expression[ApproximatePercentile]("approx_percentile"),
-    expression[StddevSamp]("std"),
-    expression[StddevSamp]("stddev"),
+    expression[ApproximatePercentile]("approx_percentile", true),
+    expression[StddevSamp]("std", true),
+    expression[StddevSamp]("stddev", true),
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
-    expression[VarianceSamp]("variance"),
+    expression[VarianceSamp]("variance", true),
     expression[VariancePop]("var_pop"),
     expression[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
-    expressionWithAlias[BoolAnd]("every"),
-    expressionWithAlias[BoolAnd]("bool_and"),
-    expressionWithAlias[BoolOr]("any"),
-    expressionWithAlias[BoolOr]("some"),
-    expressionWithAlias[BoolOr]("bool_or"),
+    expression[BoolAnd]("every", true),
+    expression[BoolAnd]("bool_and"),
+    expression[BoolOr]("any", true),
+    expression[BoolOr]("some", true),
+    expression[BoolOr]("bool_or"),
 
     // string functions
     expression[Ascii]("ascii"),
@@ -573,7 +575,7 @@ object FunctionRegistry {
   val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet
 
   /** See usage above. */
-  private def expression[T <: Expression](name: String)
+  private def expression[T <: Expression](name: String, setAlias: Boolean = false)
      (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
 
    // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
@@ -619,7 +621,9 @@ object FunctionRegistry {
        throw new AnalysisException(invalidArgumentsMsg)
      }
      try {
-        f.newInstance(expressions : _*).asInstanceOf[Expression]
+        val exp = f.newInstance(expressions : _*).asInstanceOf[Expression]
+        if (setAlias) exp.setTagValue(FUNC_ALIAS, name)
+        exp
      } catch {
        // the exception is an invocation exception. To get a meaningful message, we need the
        // cause.
@@ -631,42 +635,6 @@
    (name, (expressionInfo[T](name), builder))
  }
 
-  private def expressionWithAlias[T <: Expression](name: String)
-     (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
-    val constructors = tag.runtimeClass.getConstructors
-      .filter(_.getParameterTypes.head == classOf[String])
-    assert(constructors.length == 1)
-    val builder = (expressions: Seq[Expression]) => {
-      val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression])
-      val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
-        val validParametersCount = constructors
-          .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression]))
-          .map(_.getParameterCount - 1).distinct.sorted
-        val invalidArgumentsMsg = if (validParametersCount.length == 0) {
-          s"Invalid arguments for function $name"
-        } else {
-          val expectedNumberOfParameters = if (validParametersCount.length == 1) {
-            validParametersCount.head.toString
-          } else {
-            validParametersCount.init.mkString("one of ", ", ", " and ") +
-              validParametersCount.last
-          }
-          s"Invalid number of arguments for function $name. " +
-            s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}"
-        }
-        throw new AnalysisException(invalidArgumentsMsg)
-      }
-      try {
-        f.newInstance(name.toString +: expressions: _*).asInstanceOf[Expression]
-      } catch {
-        // the exception is an invocation exception. To get a meaningful message, we need the
-        // cause.
-        case e: Exception => throw new AnalysisException(e.getCause.getMessage)
-      }
-    }
-    (name, (expressionInfo[T](name), builder))
-  }
-
  /**
   * Creates a function registry lookup entry for cast aliases (SPARK-16730).
   * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
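Taken together, the registry changes work like this: when an entry is registered with setAlias = true, the builder stamps the name the user typed onto the newly constructed expression via the FUNC_ALIAS TreeNodeTag, and expressions consult that tag when printing their name. A minimal, self-contained Scala sketch of the pattern (the tag plumbing is simulated with a mutable map and the names are simplified; this is not Spark's actual TreeNode code):

// Standalone sketch of the FUNC_ALIAS mechanism; tag storage is modeled
// with a plain mutable map, whereas in Spark it lives on TreeNode.
import scala.collection.mutable

case class TreeNodeTag[T](name: String)

trait TaggedNode {
  private val tags = mutable.Map.empty[TreeNodeTag[_], Any]
  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = tags(tag) = value
  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] =
    tags.get(tag).map(_.asInstanceOf[T])
}

object FuncAliasDemo extends App {
  val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName")

  class VarianceSamp extends TaggedNode {
    // Fall back to the canonical name when no alias tag was set.
    def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("var_samp")
  }

  val registeredAsVarSamp = new VarianceSamp   // expression[VarianceSamp]("var_samp")
  val registeredAsVariance = new VarianceSamp  // expression[VarianceSamp]("variance", true)
  registeredAsVariance.setTagValue(FUNC_ALIAS, "variance")

  println(registeredAsVarSamp.prettyName)   // var_samp
  println(registeredAsVariance.prettyName)  // variance
}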
File: ApproximatePercentile.scala

@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
 import com.google.common.primitives.{Doubles, Ints, Longs}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
@@ -185,7 +185,8 @@ case class ApproximatePercentile(
    if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
  }
 
-  override def prettyName: String = "percentile_approx"
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("percentile_approx")
 
  override def serialize(obj: PercentileDigest): Array[Byte] = {
    ApproximatePercentile.serializer.serialize(obj)
@@ -321,4 +322,5 @@
  }
 
  val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
+
 }
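The user-visible effect of the prettyName override is in generated column names. A hedged illustration, assuming a local Spark 3.x build on the classpath (the exact default column name, including the defaulted accuracy argument, varies by version):

import org.apache.spark.sql.SparkSession

// Illustrative only: with this PR, the schema shows the alias the user typed
// ("approx_percentile") instead of the canonical "percentile_approx".
object AliasNamingDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("alias-naming").getOrCreate()
  import spark.implicits._

  Seq(1, 2, 3, 4).toDF("a").createOrReplaceTempView("t")

  // Expected field name after this PR: approx_percentile(a, 0.5, 10000)
  spark.sql("SELECT approx_percentile(a, 0.5) FROM t").printSchema()

  spark.stop()
}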
File: Average.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
  since = "1.0.0")
 case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
-  override def prettyName: String = "avg"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
  override def children: Seq[Expression] = child :: Nil
 
File: CentralMomentAgg.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -174,7 +175,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
      If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
  }
 
-  override def prettyName: String = "stddev_samp"
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")
 }
 
 // Compute the population variance of a column
@@ -215,7 +217,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
      If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
  }
 
-  override def prettyName: String = "var_samp"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")
 }
 
 @ExpressionDescription(
File: First.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -113,5 +113,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
 
  override lazy val evaluateExpression: AttributeReference = first
 
-  override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls" else ""}"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first")
+
+  override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"
 }
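Because toString now routes through prettyName, the alias also shows up when plans are rendered as strings. A standalone sketch of just the string logic, with invented values:

object FirstToStringDemo extends App {
  // Mirrors First.toString with prettyName passed in explicitly.
  def render(prettyName: String, child: String, ignoreNulls: Boolean): String =
    s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"

  println(render("first", "a", ignoreNulls = true))         // first(a) ignore nulls
  println(render("first_value", "a", ignoreNulls = false))  // first_value(a)
}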
File: Last.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -111,5 +111,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
 
  override lazy val evaluateExpression: AttributeReference = last
 
-  override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls" else ""}"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last")
+
+  override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"
 }
File: UnevaluableAggs.scala

@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -52,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
      false
  """,
  since = "3.0.0")
-case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-  override def nodeName: String = funcName
+case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
+  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and")
 }
 
 @ExpressionDescription(
@@ -68,6 +68,6 @@
      false
  """,
  since = "3.0.0")
-case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-  override def nodeName: String = funcName
+case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
+  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or")
 }
File: finishAnalysis.scala

@@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case e: RuntimeReplaceable => e.child
    case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
-    case BoolOr(_, arg) => Max(arg)
-    case BoolAnd(_, arg) => Min(arg)
+    case BoolOr(arg) => Max(arg)
+    case BoolAnd(arg) => Min(arg)
  }
 }
 
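Dropping the funcName constructor field is what lets this rule match on a single-argument pattern. A toy sketch of the same rewrite shape (the expression ADT here is invented for illustration; only the pattern-match change mirrors the rule above):

// Toy expression tree; mirrors only the pattern-match change in this rule.
sealed trait Expr
case class Attr(name: String) extends Expr
case class BoolOr(arg: Expr) extends Expr   // was BoolOr(funcName: String, arg: Expr)
case class BoolAnd(arg: Expr) extends Expr  // was BoolAnd(funcName: String, arg: Expr)
case class Max(arg: Expr) extends Expr
case class Min(arg: Expr) extends Expr

object ReplaceDemo extends App {
  // bool_or over booleans is equivalent to max, and bool_and to min, so the
  // optimizer can rewrite the unevaluable aggregates into evaluable ones.
  def replace(e: Expr): Expr = e match {
    case BoolOr(arg) => Max(arg)    // previously: case BoolOr(_, arg) => Max(arg)
    case BoolAnd(arg) => Min(arg)   // previously: case BoolAnd(_, arg) => Min(arg)
    case other => other
  }

  println(replace(BoolOr(Attr("b"))))   // Max(Attr(b))
  println(replace(BoolAnd(Attr("b"))))  // Min(Attr(b))
}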
File: ExpressionTypeCheckingSuite.scala

@@ -153,8 +153,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
    assertSuccess(Sum(Symbol("stringField")))
    assertSuccess(Average(Symbol("stringField")))
    assertSuccess(Min(Symbol("arrayField")))
-    assertSuccess(new BoolAnd("bool_and", Symbol("booleanField")))
-    assertSuccess(new BoolOr("bool_or", Symbol("booleanField")))
+    assertSuccess(new BoolAnd(Symbol("booleanField")))
+    assertSuccess(new BoolOr(Symbol("booleanField")))
 
    assertError(Min(Symbol("mapField")), "min does not support ordering on type")
    assertError(Max(Symbol("mapField")), "max does not support ordering on type")
File: group-by.sql.out (golden result file)

@@ -128,7 +128,7 @@
 NULL 1
 SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
 FROM testData
 -- !query 13 schema
-struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
+struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,variance(CAST(a AS DOUBLE)):double,stddev(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
 -- !query 13 output
 -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7
 
File: window function golden results (.sql.out)

@@ -241,7 +241,7 @@
 NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 19 schema
-struct<var_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<variance(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 19 output
 16900.0
 18491.666666666668
@@ -254,7 +254,7 @@
 NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 20 schema
-struct<var_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<variance(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 20 output
 16900.0
 18491.666666666668
@@ -267,7 +267,7 @@
 NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 21 schema
-struct<var_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<variance(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 21 output
 16900.0
 18491.666666666668
@@ -280,7 +280,7 @@
 NaN
 SELECT VARIANCE(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 22 schema
-struct<var_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<variance(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 22 output
 16900.0
 18491.666666666668
@@ -405,7 +405,7 @@
 NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 31 schema
-struct<stddev_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<stddev(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 31 output
 130.0
 135.9840676942217
@@ -419,7 +419,7 @@
 NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 32 schema
-struct<stddev_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<stddev(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 32 output
 130.0
 135.9840676942217
@@ -433,7 +433,7 @@
 NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 33 schema
-struct<stddev_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<stddev(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 33 output
 130.0
 135.9840676942217
@@ -447,7 +447,7 @@
 NaN
 SELECT STDDEV(n) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
 FROM (VALUES(0,NULL),(1,600),(2,470),(3,170),(4,430),(5,300)) r(i,n)
 -- !query 34 schema
-struct<stddev_samp(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
+struct<stddev(CAST(n AS DOUBLE)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):double>
 -- !query 34 output
 130.0
 135.9840676942217