Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
import org.apache.spark.sql.internal.SqlApiConf
Expand All @@ -46,31 +47,19 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
*
* @param inputEntry - List of all input entries that need to be generated
* @param collationType - Flag defining collation type to use
* @return
* @return - List of data generated for expression instance creation
*/
def generateData(
    inputEntry: Seq[Any],
    collationType: CollationType): Seq[Any] = {
  // Generate every entry on its own, preserving the order of the input entries.
  for (entry <- inputEntry) yield generateSingleEntry(entry, collationType)
}

/**
 * Helper function to generate the string form of every input entry.
 *
 * @param inputEntry - Input types for which string data needs to be generated
 * @param collationType - Flag defining collation type to use
 * @return - One generated entry per input type, in the order of the input
 */
def generateDataAsStrings(
    inputEntry: Seq[AbstractDataType],
    collationType: CollationType): Seq[Any] = {
  // Delegate each input type to the single-input string generator.
  for (dataType <- inputEntry) yield generateInputAsString(dataType, collationType)
}

/**
* Helper function to generate single entry of data.
* @param inputEntry - Single input entry that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - Single input entry data
*/
def generateSingleEntry(
inputEntry: Any,
Expand Down Expand Up @@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
*
* @param inputType - Single input literal type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - Literal/Expression containing expression ready for evaluation
*/
def generateLiterals(
inputType: AbstractDataType,
Expand All @@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
}
case BooleanType => Literal(true)
case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59"))
case DecimalType => Literal((new Decimal).set(5))
case _: DecimalType => Literal((new Decimal).set(5))
case _: DoubleType => Literal(5.0)
case IntegerType | NumericType | IntegralType => Literal(5)
Expand Down Expand Up @@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
case MapType =>
val key = generateLiterals(StringTypeAnyCollation, collationType)
val value = generateLiterals(StringTypeAnyCollation, collationType)
Literal.create(Map(key -> value))
CreateMap(Seq(key, value))
case MapType(keyType, valueType, _) =>
val key = generateLiterals(keyType, collationType)
val value = generateLiterals(valueType, collationType)
Literal.create(Map(key -> value))
CreateMap(Seq(key, value))
case AbstractMapType(keyType, valueType) =>
val key = generateLiterals(keyType, collationType)
val value = generateLiterals(valueType, collationType)
CreateMap(Seq(key, value))
case StructType =>
CreateNamedStruct(
Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType),
Expand All @@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
*
* @param inputType - Single input type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - String representation of an input, ready for use in a SQL query
*/
def generateInputAsString(
inputType: AbstractDataType,
Expand All @@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
}
case BooleanType => "True"
case _: DatetimeType => "date'2016-04-08'"
case DecimalType => "5.0"
case _: DecimalType => "5.0"
case _: DoubleType => "5.0"
case IntegerType | NumericType | IntegralType => "5"
Expand Down Expand Up @@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
case MapType(keyType, valueType, _) =>
"map(" + generateInputAsString(keyType, collationType) + ", " +
generateInputAsString(valueType, collationType) + ")"
case AbstractMapType(keyType, valueType) =>
"map(" + generateInputAsString(keyType, collationType) + ", " +
generateInputAsString(valueType, collationType) + ")"
case StructType =>
"named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) +
", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
Expand All @@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
*
* @param inputType - Single input type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - String representation of the inputType, for use in a SQL query
*/
def generateInputTypeAsStrings(
inputType: AbstractDataType,
Expand All @@ -244,6 +242,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
case BinaryType => "BINARY"
case BooleanType => "BOOLEAN"
case _: DatetimeType => "DATE"
case DecimalType => "DECIMAL(2, 1)"
case _: DecimalType => "DECIMAL(2, 1)"
case _: DoubleType => "DOUBLE"
case IntegerType | NumericType | IntegralType => "INT"
Expand Down Expand Up @@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
case MapType(keyType, valueType, _) =>
"map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
generateInputTypeAsStrings(valueType, collationType) + ">"
case AbstractMapType(keyType, valueType) =>
"map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
generateInputTypeAsStrings(valueType, collationType) + ">"
case StructType =>
"struct<start:" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) +
", end:" +
Expand All @@ -287,7 +289,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
/**
* Helper function to extract types of relevance
* @param inputType
* @return
* @return - Boolean that represents if inputType has/is a StringType
*/
def hasStringType(inputType: AbstractDataType): Boolean = {
inputType match {
Expand All @@ -300,7 +302,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
case AbstractArrayType(elementType) => hasStringType(elementType)
case TypeCollection(typeCollection) =>
typeCollection.exists(hasStringType)
case StructType => true
case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType))
case _ => false
}
Expand All @@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
* Helper function to replace expected parameters with expected input types.
* @param inputTypes - Input types generated by ExpectsInputType.inputTypes
* @param params - Parameters that are read from expression info
* @return
* @return - List of parameters where Expressions are replaced with input types
*/
def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = {
(inputTypes, params) match {
Expand All @@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi

/**
* Helper method to extract relevant expressions that can be walked over.
* @return
* @return - (List of relevant expressions that expect input, List of expressions to skip)
*/
def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = {
var expressionCounter = 0
Expand Down Expand Up @@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
(funInfos, toSkip)
}

/**
* Helper method to extract relevant expressions that can be walked over but are built with
* expression builder.
*
* Looks up the ExpressionInfo of every function listed by the session's function registry,
* keeps those whose class passes the ExpressionBuilder check, and then keeps those for which a
* reflective call to `build` succeeds for at least one of ten argument lists (1 to 10 literals).
*
* @return - (List of expressions that are relevant builders, List of expressions to skip).
*         The list of expressions to skip is always empty here.
*/
def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = {
// Number of functions that passed the class check; incremented inside the second filter.
var builderExpressionCounter = 0
val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId =>
spark.sessionState.catalog.lookupFunctionInfo(funcId)
}.filter(funInfo => {
// Keep only classes related to ExpressionBuilder.
// NOTE(review): Class.isAssignableFrom is true when the receiver is the same as, or a
// supertype of, the argument. As written this keeps classes that ExpressionBuilder can be
// assigned to, not implementations of ExpressionBuilder - looks inverted, please confirm.
val cl = Utils.classForName(funInfo.getClassName)
cl.isAssignableFrom(classOf[ExpressionBuilder])
}).filter(funInfo => {
builderExpressionCounter = builderExpressionCounter + 1
val cl = Utils.classForName(funInfo.getClassName)
// Reflectively look up build(String, Seq).
val method = cl.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
// Argument list that grows by one literal per attempt.
var input: Seq[Expression] = Seq.empty
// Number of attempts that threw an exception.
var i = 0
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
// Invoked with a null receiver, i.e. `build` is called as a static method.
// NOTE(review): the class name is passed as the String argument - presumably the function
// name is expected there; confirm.
method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes]
}
catch {
// Any exception from the invocation or the cast counts as a failed attempt.
case _: Exception => i = i + 1
}
}
// Drop the function only when all ten attempts failed.
if (i == 10) false
else true
}).toArray

logInfo("Total number of expression that are built: " + builderExpressionCounter)
logInfo("Number of extracted expressions of relevance: " + funInfos.length)

(funInfos, List())
}

/**
* Helper function to generate string of an expression suitable for execution.
* @param expr - Expression that needs to be converted
Expand Down Expand Up @@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
* 5) Otherwise, check if exceptions are the same
*/
test("SPARK-48280: Expression Walker for expression evaluation") {
val (funInfos, toSkip) = extractRelevantExpressions()
val (funInfosExpr, toSkip) = extractRelevantExpressions()
val (funInfosBuild, _) = extractRelevantBuilders()
val funInfos = funInfosExpr ++ funInfosBuild

for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
val cl = Utils.classForName(f.getClassName)
val TempCl = Utils.classForName(f.getClassName)
val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
val clTemp = Utils.classForName(f.getClassName)
val method = clTemp.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
val instance = {
var input: Seq[Expression] = Seq.empty
var result: Expression = null
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
val tempResult = method.invoke(null, f.getClassName, input)
if (result == null) result = tempResult.asInstanceOf[Expression]
}
catch {
case _: Exception =>
}
}
result
}
instance.getClass
}
else Utils.classForName(f.getClassName)

val headConstructor = cl.getConstructors
.zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1
val params = headConstructor.getParameters.map(p => p.getType)
Expand Down Expand Up @@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
* 5) Otherwise, check if exceptions are the same
*/
test("SPARK-48280: Expression Walker for codeGen generation") {
val (funInfos, toSkip) = extractRelevantExpressions()
val (funInfosExpr, toSkip) = extractRelevantExpressions()
val (funInfosBuild, _) = extractRelevantBuilders()
val funInfos = funInfosExpr ++ funInfosBuild

for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
val cl = Utils.classForName(f.getClassName)
val TempCl = Utils.classForName(f.getClassName)
val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
val clTemp = Utils.classForName(f.getClassName)
val method = clTemp.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
val instance = {
var input: Seq[Expression] = Seq.empty
var result: Expression = null
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
val tempResult = method.invoke(null, f.getClassName, input)
if (result == null) result = tempResult.asInstanceOf[Expression]
}
catch {
case _: Exception =>
}
}
result
}
instance.getClass
}
else Utils.classForName(f.getClassName)

val headConstructor = cl.getConstructors
.zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1
val params = headConstructor.getParameters.map(p => p.getType)
Expand Down