diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index a21c3bbb9626d..240c01a17e6ca 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -412,6 +412,8 @@ object SparkBuild extends PomBuild {
     /* Hive console settings */
     enable(Hive.settings)(hive)
 
+    enable(HiveThriftServer.settings)(hiveThriftServer)
+
     enable(SparkConnectCommon.settings)(connectCommon)
     enable(SparkConnect.settings)(connect)
     enable(SparkConnectClient.settings)(connectClient)
@@ -1203,6 +1205,14 @@ object Hive {
   )
 }
 
+object HiveThriftServer {
+  lazy val settings = Seq(
+    excludeDependencies ++= Seq(
+      ExclusionRule("org.apache.hive", "hive-llap-common"),
+      ExclusionRule("org.apache.hive", "hive-llap-client"))
+  )
+}
+
 object YARN {
   val genConfigProperties = TaskKey[Unit]("gen-config-properties",
     "Generate config.properties which contains a setting whether Hadoop is provided or not")
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 8ecc5b56ca0ea..3eef95f34f031 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -148,16 +148,6 @@
       <artifactId>byte-buddy-agent</artifactId>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>${hive.group}</groupId>
-      <artifactId>hive-llap-common</artifactId>
-      <scope>${hive.llap.scope}</scope>
-    </dependency>
-    <dependency>
-      <groupId>${hive.group}</groupId>
-      <artifactId>hive-llap-client</artifactId>
-      <scope>${hive.llap.scope}</scope>
-    </dependency>
     <dependency>
       <groupId>net.sf.jpam</groupId>
       <artifactId>jpam</artifactId>
diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java
index 4b55453ec7a8b..9d9b6f1c7b0e1 100644
--- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java
+++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java
@@ -673,15 +673,23 @@ public void close() throws HiveSQLException {
         hiveHist.closeStream();
       }
       try {
+        // Forcibly initialize thread local Hive so that
+        // SessionState#unCacheDataNucleusClassLoaders won't trigger
+        // Hive built-in UDFs initialization.
+        Hive.getWithoutRegisterFns(sessionState.getConf());
         sessionState.close();
       } finally {
         sessionState = null;
       }
-    } catch (IOException ioe) {
+    } catch (IOException | HiveException ioe) {
      throw new HiveSQLException("Failure to close", ioe);
    } finally {
      if (sessionState != null) {
        try {
+          // Forcibly initialize thread local Hive so that
+          // SessionState#unCacheDataNucleusClassLoaders won't trigger
+          // Hive built-in UDFs initialization.
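+          // (Hive.getWithoutRegisterFns returns the thread-local Hive client
+          // without registering the Hive built-in functions first.)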
+          Hive.getWithoutRegisterFns(sessionState.getConf());
           sessionState.close();
         } catch (Throwable t) {
           LOG.warn("Error closing session", t);
diff --git a/sql/hive-thriftserver/src/test/resources/log4j2.properties b/sql/hive-thriftserver/src/test/resources/log4j2.properties
index e6753047c9055..207fd3c22ab93 100644
--- a/sql/hive-thriftserver/src/test/resources/log4j2.properties
+++ b/sql/hive-thriftserver/src/test/resources/log4j2.properties
@@ -92,3 +92,6 @@ logger.parquet2.level = error
 
 logger.thriftserver.name = org.apache.spark.sql.hive.thriftserver.SparkExecuteStatementOperation
 logger.thriftserver.level = off
+
+logger.dagscheduler.name = org.apache.spark.scheduler.DAGScheduler
+logger.dagscheduler.level = error
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java
new file mode 100644
index 0000000000000..333ae0151a5a6
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.*;
+
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.spark.internal.SparkLogger;
+import org.apache.spark.internal.SparkLoggerFactory;
+
+/**
+ * Copies some methods from {@link org.apache.hadoop.hive.ql.exec.FunctionRegistry}
+ * to avoid initializing Hive built-in UDFs.
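+ * The helpers here are shared by Spark's Hive UDF evaluators and by
+ * {@link SparkDefaultUDFMethodResolver} and {@link SparkDefaultUDAFEvaluatorResolver}.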
+ * <p>
+ * The code is based on Hive 2.3.10.
+ */
+public class HiveFunctionRegistryUtils {
+
+  public static final SparkLogger LOG =
+    SparkLoggerFactory.getLogger(HiveFunctionRegistryUtils.class);
+
+  /**
+   * This method is shared between UDFRegistry and UDAFRegistry. methodName will
+   * be "evaluate" for UDFRegistry, and "aggregate"/"evaluate"/"evaluatePartial"
+   * for UDAFRegistry.
+   * @throws UDFArgumentException
+   */
+  public static Method getMethodInternal(Class<?> udfClass,
+      String methodName, boolean exact, List<TypeInfo> argumentClasses)
+      throws UDFArgumentException {
+
+    List<Method> mlist = new ArrayList<>();
+
+    for (Method m : udfClass.getMethods()) {
+      if (m.getName().equals(methodName)) {
+        mlist.add(m);
+      }
+    }
+
+    return getMethodInternal(udfClass, mlist, exact, argumentClasses);
+  }
+
+  /**
+   * Gets the closest matching method corresponding to the argument list from a
+   * list of methods.
+   *
+   * @param mlist
+   *          The list of methods to inspect.
+   * @param exact
+   *          Boolean to indicate whether this is an exact match or not.
+   * @param argumentsPassed
+   *          The classes for the argument.
+   * @return The matching method.
+   */
+  public static Method getMethodInternal(Class<?> udfClass, List<Method> mlist, boolean exact,
+      List<TypeInfo> argumentsPassed) throws UDFArgumentException {
+
+    // result
+    List<Method> udfMethods = new ArrayList<>();
+    // The cost of the result
+    int leastConversionCost = Integer.MAX_VALUE;
+
+    for (Method m : mlist) {
+      List<TypeInfo> argumentsAccepted = TypeInfoUtils.getParameterTypeInfos(m,
+        argumentsPassed.size());
+      if (argumentsAccepted == null) {
+        // null means the method does not accept number of arguments passed.
+        continue;
+      }
+
+      boolean match = (argumentsAccepted.size() == argumentsPassed.size());
+      int conversionCost = 0;
+
+      for (int i = 0; i < argumentsPassed.size() && match; i++) {
+        int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), exact);
+        if (cost == -1) {
+          match = false;
+        } else {
+          conversionCost += cost;
+        }
+      }
+
+      LOG.debug("Method {} match: passed = {} accepted = {} method = {}",
+        match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m);
+
+      if (match) {
+        // Always choose the function with least implicit conversions.
+        if (conversionCost < leastConversionCost) {
+          udfMethods.clear();
+          udfMethods.add(m);
+          leastConversionCost = conversionCost;
+          // Found an exact match
+          if (leastConversionCost == 0) {
+            break;
+          }
+        } else if (conversionCost == leastConversionCost) {
+          // Ambiguous call: two methods with the same number of implicit
+          // conversions
+          udfMethods.add(m);
+          // Don't break! We might find a better match later.
+        } else {
+          // do nothing if implicitConversions > leastImplicitConversions
+        }
+      }
+    }
+
+    if (udfMethods.size() == 0) {
+      // No matching methods found
+      throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist);
+    }
+
+    if (udfMethods.size() > 1) {
+      // First try selecting methods based on the type affinity of the arguments passed
+      // to the candidate method arguments.
+      filterMethodsByTypeAffinity(udfMethods, argumentsPassed);
+    }
+
+    if (udfMethods.size() > 1) {
+
+      // if the only difference is numeric types, pick the method
+      // with the smallest overall numeric type.
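+      // e.g. if evaluate(int) and evaluate(double) both remain candidates for an
+      // int argument, evaluate(int) wins because INT ranks below DOUBLE among the
+      // numeric types.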
+      int lowestNumericType = Integer.MAX_VALUE;
+      boolean multiple = true;
+      Method candidate = null;
+      List<TypeInfo> referenceArguments = null;
+
+      for (Method m: udfMethods) {
+        int maxNumericType = 0;
+
+        List<TypeInfo> argumentsAccepted =
+          TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size());
+
+        if (referenceArguments == null) {
+          // keep the arguments for reference - we want all the non-numeric
+          // arguments to be the same
+          referenceArguments = argumentsAccepted;
+        }
+
+        Iterator<TypeInfo> referenceIterator = referenceArguments.iterator();
+
+        for (TypeInfo accepted: argumentsAccepted) {
+          TypeInfo reference = referenceIterator.next();
+
+          boolean acceptedIsPrimitive = false;
+          PrimitiveCategory acceptedPrimCat = PrimitiveCategory.UNKNOWN;
+          if (accepted.getCategory() == Category.PRIMITIVE) {
+            acceptedIsPrimitive = true;
+            acceptedPrimCat = ((PrimitiveTypeInfo) accepted).getPrimitiveCategory();
+          }
+          if (acceptedIsPrimitive && TypeInfoUtils.numericTypes.containsKey(acceptedPrimCat)) {
+            // We're looking for the udf with the smallest maximum numeric type.
+            int typeValue = TypeInfoUtils.numericTypes.get(acceptedPrimCat);
+            maxNumericType = typeValue > maxNumericType ? typeValue : maxNumericType;
+          } else if (!accepted.equals(reference)) {
+            // There are non-numeric arguments that don't match from one UDF to
+            // another. We give up at this point.
+            throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist);
+          }
+        }
+
+        if (lowestNumericType > maxNumericType) {
+          multiple = false;
+          lowestNumericType = maxNumericType;
+          candidate = m;
+        } else if (maxNumericType == lowestNumericType) {
+          // multiple udfs with the same max type. Unless we find a lower one
+          // we'll give up.
+          multiple = true;
+        }
+      }
+
+      if (!multiple) {
+        return candidate;
+      } else {
+        throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist);
+      }
+    }
+    return udfMethods.get(0);
+  }
+
+  public static Object invoke(Method m, Object thisObject, Object... arguments)
+      throws HiveException {
+    Object o;
+    try {
+      o = m.invoke(thisObject, arguments);
+    } catch (Exception e) {
+      StringBuilder argumentString = new StringBuilder();
+      if (arguments == null) {
+        argumentString.append("null");
+      } else {
+        argumentString.append("{");
+        for (int i = 0; i < arguments.length; i++) {
+          if (i > 0) {
+            argumentString.append(",");
+          }
+
+          argumentString.append(arguments[i]);
+        }
+        argumentString.append("}");
+      }
+
+      String detailedMsg = e instanceof java.lang.reflect.InvocationTargetException ?
+        e.getCause().getMessage() : e.getMessage();
+
+      throw new HiveException("Unable to execute method " + m + " with arguments "
+        + argumentString + ":" + detailedMsg, e);
+    }
+    return o;
+  }
+
+  /**
+   * Returns -1 if passed does not match accepted. Otherwise return the cost
+   * (usually 0 for no conversion and 1 for conversion).
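+   * For example, an int argument passed to an int parameter costs 0, while an
+   * int argument passed to a long parameter costs 1 (one implicit conversion).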
+   */
+  public static int matchCost(TypeInfo argumentPassed,
+      TypeInfo argumentAccepted, boolean exact) {
+    if (argumentAccepted.equals(argumentPassed)
+        || TypeInfoUtils.doPrimitiveCategoriesMatch(argumentPassed, argumentAccepted)) {
+      // matches
+      return 0;
+    }
+    if (argumentPassed.equals(TypeInfoFactory.voidTypeInfo)) {
+      // passing null matches everything
+      return 0;
+    }
+    if (argumentPassed.getCategory().equals(Category.LIST)
+        && argumentAccepted.getCategory().equals(Category.LIST)) {
+      // lists are compatible if and only-if the elements are compatible
+      TypeInfo argumentPassedElement = ((ListTypeInfo) argumentPassed)
+        .getListElementTypeInfo();
+      TypeInfo argumentAcceptedElement = ((ListTypeInfo) argumentAccepted)
+        .getListElementTypeInfo();
+      return matchCost(argumentPassedElement, argumentAcceptedElement, exact);
+    }
+    if (argumentPassed.getCategory().equals(Category.MAP)
+        && argumentAccepted.getCategory().equals(Category.MAP)) {
+      // maps are compatible if and only-if the keys and values are compatible
+      TypeInfo argumentPassedKey = ((MapTypeInfo) argumentPassed)
+        .getMapKeyTypeInfo();
+      TypeInfo argumentAcceptedKey = ((MapTypeInfo) argumentAccepted)
+        .getMapKeyTypeInfo();
+      TypeInfo argumentPassedValue = ((MapTypeInfo) argumentPassed)
+        .getMapValueTypeInfo();
+      TypeInfo argumentAcceptedValue = ((MapTypeInfo) argumentAccepted)
+        .getMapValueTypeInfo();
+      int cost1 = matchCost(argumentPassedKey, argumentAcceptedKey, exact);
+      int cost2 = matchCost(argumentPassedValue, argumentAcceptedValue, exact);
+      if (cost1 < 0 || cost2 < 0) {
+        return -1;
+      }
+      return Math.max(cost1, cost2);
+    }
+
+    if (argumentAccepted.equals(TypeInfoFactory.unknownTypeInfo)) {
+      // accepting Object means accepting everything,
+      // but there is a conversion cost.
+      return 1;
+    }
+    if (!exact && TypeInfoUtils.implicitConvertible(argumentPassed, argumentAccepted)) {
+      return 1;
+    }
+
+    return -1;
+  }
+
+  /**
+   * Given a set of candidate methods and list of argument types, try to
+   * select the best candidate based on how close the passed argument types are
+   * to the candidate argument types.
+   * For a varchar argument, we would prefer evaluate(string) over evaluate(double).
+   * @param udfMethods list of candidate methods
+   * @param argumentsPassed list of argument types to match to the candidate methods
+   */
+  static void filterMethodsByTypeAffinity(List<Method> udfMethods, List<TypeInfo> argumentsPassed) {
+    if (udfMethods.size() > 1) {
+      // Prefer methods with a closer signature based on the primitive grouping of each argument.
+      // Score each method based on its similarity to the passed argument types.
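+      // e.g. for a varchar argument, evaluate(string) outscores evaluate(double),
+      // since varchar and string share the STRING_GROUP primitive grouping.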
+      int currentScore = 0;
+      int bestMatchScore = 0;
+      Method bestMatch = null;
+      for (Method m: udfMethods) {
+        currentScore = 0;
+        List<TypeInfo> argumentsAccepted =
+          TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size());
+        Iterator<TypeInfo> argsPassedIter = argumentsPassed.iterator();
+        for (TypeInfo acceptedType : argumentsAccepted) {
+          // Check the affinity of the argument passed in with the accepted argument,
+          // based on the PrimitiveGrouping
+          TypeInfo passedType = argsPassedIter.next();
+          if (acceptedType.getCategory() == Category.PRIMITIVE
+              && passedType.getCategory() == Category.PRIMITIVE) {
+            PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(
+              ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory());
+            PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(
+              ((PrimitiveTypeInfo) passedType).getPrimitiveCategory());
+            if (acceptedPg == passedPg) {
+              // The passed argument matches somewhat closely with an accepted argument
+              ++currentScore;
+            }
+          }
+        }
+        // Check if the score for this method is any better relative to others
+        if (currentScore > bestMatchScore) {
+          bestMatchScore = currentScore;
+          bestMatch = m;
+        } else if (currentScore == bestMatchScore) {
+          bestMatch = null; // no longer a best match if more than one.
+        }
+      }
+
+      if (bestMatch != null) {
+        // Found a best match during this processing, use it.
+        udfMethods.clear();
+        udfMethods.add(bestMatch);
+      }
+    }
+  }
+}
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java
new file mode 100644
index 0000000000000..8cbc5a2cc96be
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * An equivalent implementation of {@link DefaultUDAFEvaluatorResolver} that eliminates
+ * calls to {@link FunctionRegistry} to avoid initializing Hive built-in UDFs.
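+ * It is used by {@link org.apache.hadoop.hive.ql.udf.generic.SparkGenericUDAFBridge}
+ * in place of {@link DefaultUDAFEvaluatorResolver}.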
+ * <p>
+ * The code is based on Hive 2.3.10.
+ */
+@SuppressWarnings("deprecation")
+public class SparkDefaultUDAFEvaluatorResolver implements UDAFEvaluatorResolver {
+
+  /**
+   * The class of the UDAF.
+   */
+  private final Class<? extends UDAF> udafClass;
+
+  /**
+   * Constructor. This constructor extracts udafClass from {@link DefaultUDAFEvaluatorResolver}.
+   */
+  @SuppressWarnings("unchecked")
+  public SparkDefaultUDAFEvaluatorResolver(DefaultUDAFEvaluatorResolver wrapped) {
+    try {
+      Field udfClassField = wrapped.getClass().getDeclaredField("udafClass");
+      udfClassField.setAccessible(true);
+      this.udafClass = (Class<? extends UDAF>) udfClassField.get(wrapped);
+    } catch (ReflectiveOperationException rethrow) {
+      throw new RuntimeException(rethrow);
+    }
+  }
+
+  /**
+   * Gets the evaluator class for the UDAF given the parameter types.
+   *
+   * @param argClasses
+   *          The list of the parameter types.
+   */
+  @SuppressWarnings("unchecked")
+  public Class<? extends UDAFEvaluator> getEvaluatorClass(
+      List<TypeInfo> argClasses) throws UDFArgumentException {
+
+    ArrayList<Class<? extends UDAFEvaluator>> classList = new ArrayList<>();
+
+    // Add all the public member classes that implement an evaluator
+    for (Class<?> enclClass : udafClass.getClasses()) {
+      if (UDAFEvaluator.class.isAssignableFrom(enclClass)) {
+        classList.add((Class<? extends UDAFEvaluator>) enclClass);
+      }
+    }
+
+    // Next we locate all the iterate methods for each of these classes.
+    ArrayList<Method> mList = new ArrayList<>();
+    ArrayList<Class<? extends UDAFEvaluator>> cList = new ArrayList<>();
+    for (Class<? extends UDAFEvaluator> evaluator : classList) {
+      for (Method m : evaluator.getMethods()) {
+        if (m.getName().equalsIgnoreCase("iterate")) {
+          mList.add(m);
+          cList.add(evaluator);
+        }
+      }
+    }
+
+    Method m = HiveFunctionRegistryUtils.getMethodInternal(udafClass, mList, false, argClasses);
+
+    // Find the class that has this method.
+    // Note that Method.getDeclaringClass() may not work here because the method
+    // can be inherited from a base class.
+    int found = -1;
+    for (int i = 0; i < mList.size(); i++) {
+      if (mList.get(i) == m) {
+        if (found == -1) {
+          found = i;
+        } else {
+          throw new AmbiguousMethodException(udafClass, argClasses, mList);
+        }
+      }
+    }
+    assert (found != -1);
+
+    return cList.get(found);
+  }
+}
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java
new file mode 100644
index 0000000000000..1c2cbb108e1b1
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import org.apache.hadoop.hive.serde2.typeinfo.*;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.List;
+
+/**
+ * An equivalent implementation of {@link DefaultUDFMethodResolver} that eliminates calls
+ * to {@link org.apache.hadoop.hive.ql.exec.FunctionRegistry} to avoid initializing Hive
+ * built-in UDFs.
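+ * It is used by Spark's HiveSimpleUDFEvaluator in place of {@link DefaultUDFMethodResolver}.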
+ * <p>
+ * The code is based on Hive 2.3.10.
+ */
+public class SparkDefaultUDFMethodResolver implements UDFMethodResolver {
+
+  /**
+   * The class of the UDF.
+   */
+  private final Class<? extends UDF> udfClass;
+
+  /**
+   * Constructor. This constructor extracts udfClass from {@link DefaultUDFMethodResolver}.
+   */
+  @SuppressWarnings("unchecked")
+  public SparkDefaultUDFMethodResolver(DefaultUDFMethodResolver wrapped) {
+    try {
+      Field udfClassField = wrapped.getClass().getDeclaredField("udfClass");
+      udfClassField.setAccessible(true);
+      this.udfClass = (Class<? extends UDF>) udfClassField.get(wrapped);
+    } catch (ReflectiveOperationException rethrow) {
+      throw new RuntimeException(rethrow);
+    }
+  }
+
+  /**
+   * Gets the evaluate method for the UDF given the parameter types.
+   *
+   * @param argClasses
+   *          The list of the argument types that need to be matched with the
+   *          evaluate function signature.
+   */
+  @Override
+  public Method getEvalMethod(List<TypeInfo> argClasses) throws UDFArgumentException {
+    return HiveFunctionRegistryUtils.getMethodInternal(udfClass, "evaluate", false, argClasses);
+  }
+}
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java
new file mode 100644
index 0000000000000..72a758de90a26
--- /dev/null
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import java.io.Serial;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.exec.DefaultUDAFEvaluatorResolver;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.HiveFunctionRegistryUtils;
+import org.apache.hadoop.hive.ql.exec.SparkDefaultUDAFEvaluatorResolver;
+import org.apache.hadoop.hive.ql.exec.UDAF;
+import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.hive.ql.exec.UDAFEvaluatorResolver;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+/**
+ * An equivalent implementation of {@link GenericUDAFBridge} that eliminates calls
+ * to {@link FunctionRegistry} to avoid initializing Hive built-in UDFs.
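+ * It is used by Spark's HiveUDAFFunction in place of {@link GenericUDAFBridge} when
+ * bridging old-style {@link UDAF} implementations.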
+ * <p>
+ * The code is based on Hive 2.3.10.
+ */
+@SuppressWarnings("deprecation")
+public class SparkGenericUDAFBridge extends GenericUDAFBridge {
+
+  public SparkGenericUDAFBridge(UDAF udaf) {
+    super(udaf);
+  }
+
+  @Override
+  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
+
+    UDAFEvaluatorResolver resolver = udaf.getResolver();
+    if (resolver instanceof DefaultUDAFEvaluatorResolver) {
+      resolver = new SparkDefaultUDAFEvaluatorResolver((DefaultUDAFEvaluatorResolver) resolver);
+    }
+    Class<? extends UDAFEvaluator> udafEvaluatorClass =
+      resolver.getEvaluatorClass(Arrays.asList(parameters));
+
+    return new SparkGenericUDAFBridgeEvaluator(udafEvaluatorClass);
+  }
+
+  public static class SparkGenericUDAFBridgeEvaluator
+      extends GenericUDAFBridge.GenericUDAFBridgeEvaluator {
+
+    @Serial
+    private static final long serialVersionUID = 1L;
+
+    // Used by serialization only
+    public SparkGenericUDAFBridgeEvaluator() {
+    }
+
+    public SparkGenericUDAFBridgeEvaluator(
+        Class<? extends UDAFEvaluator> udafEvaluator) {
+      super(udafEvaluator);
+    }
+
+    @Override
+    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+      HiveFunctionRegistryUtils.invoke(iterateMethod, ((UDAFAgg) agg).ueObject,
+        conversionHelper.convertIfNecessary(parameters));
+    }
+
+    @Override
+    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
+      HiveFunctionRegistryUtils.invoke(mergeMethod, ((UDAFAgg) agg).ueObject,
+        conversionHelper.convertIfNecessary(partial));
+    }
+
+    @Override
+    public Object terminate(AggregationBuffer agg) throws HiveException {
+      return HiveFunctionRegistryUtils.invoke(terminateMethod, ((UDAFAgg) agg).ueObject);
+    }
+
+    @Override
+    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+      return HiveFunctionRegistryUtils.invoke(terminatePartialMethod, ((UDAFAgg) agg).ueObject);
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
index 409be67f7af4c..866d88dee8783 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql.hive
 
+import java.lang.reflect.Method
+
 import scala.jdk.CollectionConverters._
 
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF}
+import org.apache.hadoop.hive.ql.exec.{DefaultUDFMethodResolver, HiveFunctionRegistryUtils, SparkDefaultUDFMethodResolver, UDF, UDFArgumentException}
+import org.apache.hadoop.hive.ql.metadata.HiveException
 import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory, ObjectInspectorUtils}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -70,8 +73,10 @@ class HiveSimpleUDFEvaluator(
     extends HiveUDFEvaluatorBase[UDF](funcWrapper, children) {
 
   @transient
-  lazy val method = function.getResolver.
-    getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
+  lazy val method: Method = (function.getResolver match {
+    case r: DefaultUDFMethodResolver => new SparkDefaultUDFMethodResolver(r)
+    case r => r
+  }).getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
 
   @transient
   private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -98,7 +103,7 @@ class HiveSimpleUDFEvaluator(
       method.getGenericReturnType, ObjectInspectorOptions.JAVA))
 
   override def doEvaluate(): Any = {
-    val ret = FunctionRegistry.invoke(
+    val ret = HiveFunctionRegistryUtils.invoke(
       method,
       function,
       conversionHelper.convertIfNecessary(inputs: _*): _*)
@@ -111,17 +116,41 @@ class HiveGenericUDFEvaluator(
     extends HiveUDFEvaluatorBase[GenericUDF](funcWrapper, children) {
 
   @transient
-  private lazy val argumentInspectors = children.map(toInspector)
+  private lazy val argumentInspectors = children.map(toInspector).toArray
 
   @transient
   lazy val returnInspector = {
-    function.initializeAndFoldConstants(argumentInspectors.toArray)
+    // Inline o.a.h.hive.ql.udf.generic.GenericUDF#initializeAndFoldConstants, but
+    // eliminate calls to o.a.h.hive.ql.exec.FunctionRegistry to avoid initializing
+    // Hive built-in UDFs.
+    val oi = function.initialize(argumentInspectors)
+    // If the UDF depends on any external resources, we can't fold because the
+    // resources may not be available at compile time.
+    if (function.getRequiredFiles == null && function.getRequiredJars == null &&
+      argumentInspectors.forall(ObjectInspectorUtils.isConstantObjectInspector) &&
+      !ObjectInspectorUtils.isConstantObjectInspector(oi) &&
+      isUDFDeterministic &&
+      ObjectInspectorUtils.supportsConstantObjectInspector(oi)) {
+      val argumentValues: Array[DeferredObject] = argumentInspectors.map { argumentInspector =>
+        new GenericUDF.DeferredJavaObject(
+          argumentInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue)
+      }
+      try {
+        val constantValue = function.evaluate(argumentValues)
+        ObjectInspectorUtils.getConstantObjectInspector(oi, constantValue)
+      } catch {
+        case e: HiveException =>
+          throw new UDFArgumentException(e)
+      }
+    } else {
+      oi
+    }
   }
 
   @transient
   private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
     case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
-  }.toArray[DeferredObject]
+  }
 
   @transient
   private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 227c6a618e3d4..bf708eecf0c0c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -353,7 +353,7 @@ private[hive] case class HiveUDAFFunction(
 
   private def newEvaluator(): GenericUDAFEvaluator = {
     val resolver = if (isUDAFBridgeRequired) {
-      new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
+      new SparkGenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }