From ec30d943ebe0f34b760d63637ffab292442be3c4 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Tue, 11 Mar 2025 11:41:59 +0800 Subject: [PATCH 1/6] [SPARK-51466][SQL][HIVE] Eliminate Hive built-in UDFs initialization on Hive UDF evaluation --- .../exec/SparkDefaultUDFMethodResolver.java | 341 ++++++++++++++++++ .../spark/sql/hive/hiveUDFEvaluators.scala | 63 +++- 2 files changed, 393 insertions(+), 11 deletions(-) create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java new file mode 100644 index 0000000000000..34a6c683833ce --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec; + +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.*; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.internal.SparkLogger; +import org.apache.spark.internal.SparkLoggerFactory; + +/** + * A equivalent implementation of {@link DefaultUDFMethodResolver}, but eliminate calls + * of {@link org.apache.hadoop.hive.ql.exec.FunctionRegistry} to avoid initializing Hive + * built-in UDFs. + */ +public class SparkDefaultUDFMethodResolver implements UDFMethodResolver { + + public static final SparkLogger LOG = + SparkLoggerFactory.getLogger(SparkDefaultUDFMethodResolver.class); + + /** + * The class of the UDF. + */ + private final Class udfClass; + + /** + * Constructor. This constructor extract udfClass from {@link DefaultUDFMethodResolver} + */ + @SuppressWarnings("unchecked") + public SparkDefaultUDFMethodResolver( + DefaultUDFMethodResolver wrapped) throws ReflectiveOperationException { + Field udfClassField = wrapped.getClass().getDeclaredField("udfClass"); + udfClassField.setAccessible(true); + this.udfClass = (Class) udfClassField.get(wrapped); + } + + /** + * Gets the evaluate method for the UDF given the parameter types. + * + * @param argClasses + * The list of the argument types that need to matched with the + * evaluate function signature. 
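+   * <p>Illustrative example (the overloads are hypothetical, not part of this
+   * class): for a UDF exposing {@code evaluate(Text)} and
+   * {@code evaluate(IntWritable)}, a single string argument type resolves to
+   * {@code evaluate(Text)} at conversion cost 0.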
+ */ + @Override + public Method getEvalMethod(List argClasses) throws UDFArgumentException { + return getMethodInternal(udfClass, "evaluate", false, argClasses); + } + + // Below methods are copied from Hive 2.3.10 o.a.h.hive.ql.exec.FunctionRegistry + + /** + * This method is shared between UDFRegistry and UDAFRegistry. methodName will + * be "evaluate" for UDFRegistry, and "aggregate"/"evaluate"/"evaluatePartial" + * for UDAFRegistry. + * @throws UDFArgumentException + */ + public static Method getMethodInternal(Class udfClass, + String methodName, boolean exact, List argumentClasses) + throws UDFArgumentException { + + List mlist = new ArrayList<>(); + + for (Method m : udfClass.getMethods()) { + if (m.getName().equals(methodName)) { + mlist.add(m); + } + } + + return getMethodInternal(udfClass, mlist, exact, argumentClasses); + } + + /** + * Gets the closest matching method corresponding to the argument list from a + * list of methods. + * + * @param mlist + * The list of methods to inspect. + * @param exact + * Boolean to indicate whether this is an exact match or not. + * @param argumentsPassed + * The classes for the argument. + * @return The matching method. + */ + public static Method getMethodInternal(Class udfClass, List mlist, boolean exact, + List argumentsPassed) throws UDFArgumentException { + + // result + List udfMethods = new ArrayList<>(); + // The cost of the result + int leastConversionCost = Integer.MAX_VALUE; + + for (Method m : mlist) { + List argumentsAccepted = TypeInfoUtils.getParameterTypeInfos(m, + argumentsPassed.size()); + if (argumentsAccepted == null) { + // null means the method does not accept number of arguments passed. + continue; + } + + boolean match = (argumentsAccepted.size() == argumentsPassed.size()); + int conversionCost = 0; + + for (int i = 0; i < argumentsPassed.size() && match; i++) { + int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), exact); + if (cost == -1) { + match = false; + } else { + conversionCost += cost; + } + } + if (LOG.isDebugEnabled()) { + LOG.debug("Method {} match: passed = {} accepted = {} method = {}", + match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m); + } + if (match) { + // Always choose the function with least implicit conversions. + if (conversionCost < leastConversionCost) { + udfMethods.clear(); + udfMethods.add(m); + leastConversionCost = conversionCost; + // Found an exact match + if (leastConversionCost == 0) { + break; + } + } else if (conversionCost == leastConversionCost) { + // Ambiguous call: two methods with the same number of implicit + // conversions + udfMethods.add(m); + // Don't break! We might find a better match later. + } else { + // do nothing if implicitConversions > leastImplicitConversions + } + } + } + + if (udfMethods.size() == 0) { + // No matching methods found + throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist); + } + + if (udfMethods.size() > 1) { + // First try selecting methods based on the type affinity of the arguments passed + // to the candidate method arguments. + filterMethodsByTypeAffinity(udfMethods, argumentsPassed); + } + + if (udfMethods.size() > 1) { + + // if the only difference is numeric types, pick the method + // with the smallest overall numeric type. 
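+      // Illustrative example: with candidates evaluate(int) and evaluate(double)
+      // and a passed tinyint, both cost one implicit conversion, so the tie is
+      // broken in favor of evaluate(int), whose maximum numeric type is smaller.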
+ int lowestNumericType = Integer.MAX_VALUE; + boolean multiple = true; + Method candidate = null; + List referenceArguments = null; + + for (Method m: udfMethods) { + int maxNumericType = 0; + + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + + if (referenceArguments == null) { + // keep the arguments for reference - we want all the non-numeric + // arguments to be the same + referenceArguments = argumentsAccepted; + } + + Iterator referenceIterator = referenceArguments.iterator(); + + for (TypeInfo accepted: argumentsAccepted) { + TypeInfo reference = referenceIterator.next(); + + boolean acceptedIsPrimitive = false; + PrimitiveCategory acceptedPrimCat = PrimitiveCategory.UNKNOWN; + if (accepted.getCategory() == Category.PRIMITIVE) { + acceptedIsPrimitive = true; + acceptedPrimCat = ((PrimitiveTypeInfo) accepted).getPrimitiveCategory(); + } + if (acceptedIsPrimitive && TypeInfoUtils.numericTypes.containsKey(acceptedPrimCat)) { + // We're looking for the udf with the smallest maximum numeric type. + int typeValue = TypeInfoUtils.numericTypes.get(acceptedPrimCat); + maxNumericType = typeValue > maxNumericType ? typeValue : maxNumericType; + } else if (!accepted.equals(reference)) { + // There are non-numeric arguments that don't match from one UDF to + // another. We give up at this point. + throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); + } + } + + if (lowestNumericType > maxNumericType) { + multiple = false; + lowestNumericType = maxNumericType; + candidate = m; + } else if (maxNumericType == lowestNumericType) { + // multiple udfs with the same max type. Unless we find a lower one + // we'll give up. + multiple = true; + } + } + + if (!multiple) { + return candidate; + } else { + throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); + } + } + return udfMethods.get(0); + } + + /** + * Returns -1 if passed does not match accepted. Otherwise return the cost + * (usually 0 for no conversion and 1 for conversion). 
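+   * <p>Examples: int passed to int costs 0; a void argument (a SQL NULL
+   * literal) matches anything at cost 0; int passed to double costs 1 via
+   * implicit conversion; an Object parameter accepts anything at cost 1;
+   * incompatible pairs return -1.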
+ */ + public static int matchCost(TypeInfo argumentPassed, + TypeInfo argumentAccepted, boolean exact) { + if (argumentAccepted.equals(argumentPassed) + || TypeInfoUtils.doPrimitiveCategoriesMatch(argumentPassed, argumentAccepted)) { + // matches + return 0; + } + if (argumentPassed.equals(TypeInfoFactory.voidTypeInfo)) { + // passing null matches everything + return 0; + } + if (argumentPassed.getCategory().equals(Category.LIST) + && argumentAccepted.getCategory().equals(Category.LIST)) { + // lists are compatible if and only-if the elements are compatible + TypeInfo argumentPassedElement = ((ListTypeInfo) argumentPassed) + .getListElementTypeInfo(); + TypeInfo argumentAcceptedElement = ((ListTypeInfo) argumentAccepted) + .getListElementTypeInfo(); + return matchCost(argumentPassedElement, argumentAcceptedElement, exact); + } + if (argumentPassed.getCategory().equals(Category.MAP) + && argumentAccepted.getCategory().equals(Category.MAP)) { + // lists are compatible if and only-if the elements are compatible + TypeInfo argumentPassedKey = ((MapTypeInfo) argumentPassed) + .getMapKeyTypeInfo(); + TypeInfo argumentAcceptedKey = ((MapTypeInfo) argumentAccepted) + .getMapKeyTypeInfo(); + TypeInfo argumentPassedValue = ((MapTypeInfo) argumentPassed) + .getMapValueTypeInfo(); + TypeInfo argumentAcceptedValue = ((MapTypeInfo) argumentAccepted) + .getMapValueTypeInfo(); + int cost1 = matchCost(argumentPassedKey, argumentAcceptedKey, exact); + int cost2 = matchCost(argumentPassedValue, argumentAcceptedValue, exact); + if (cost1 < 0 || cost2 < 0) { + return -1; + } + return Math.max(cost1, cost2); + } + + if (argumentAccepted.equals(TypeInfoFactory.unknownTypeInfo)) { + // accepting Object means accepting everything, + // but there is a conversion cost. + return 1; + } + if (!exact && TypeInfoUtils.implicitConvertible(argumentPassed, argumentAccepted)) { + return 1; + } + + return -1; + } + + /** + * Given a set of candidate methods and list of argument types, try to + * select the best candidate based on how close the passed argument types are + * to the candidate argument types. + * For a varchar argument, we would prefer evaluate(string) over evaluate(double). + * @param udfMethods list of candidate methods + * @param argumentsPassed list of argument types to match to the candidate methods + */ + static void filterMethodsByTypeAffinity(List udfMethods, List argumentsPassed) { + if (udfMethods.size() > 1) { + // Prefer methods with a closer signature based on the primitive grouping of each argument. + // Score each method based on its similarity to the passed argument types. 
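+      // Each passed argument whose PrimitiveGrouping (e.g. STRING_GROUP,
+      // NUMERIC_GROUP) matches the accepted parameter's grouping scores one
+      // point; a unique top scorer wins, while ties leave the list untouched.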
+ int currentScore = 0; + int bestMatchScore = 0; + Method bestMatch = null; + for (Method m: udfMethods) { + currentScore = 0; + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + Iterator argsPassedIter = argumentsPassed.iterator(); + for (TypeInfo acceptedType : argumentsAccepted) { + // Check the affinity of the argument passed in with the accepted argument, + // based on the PrimitiveGrouping + TypeInfo passedType = argsPassedIter.next(); + if (acceptedType.getCategory() == Category.PRIMITIVE + && passedType.getCategory() == Category.PRIMITIVE) { + PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory()); + PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) passedType).getPrimitiveCategory()); + if (acceptedPg == passedPg) { + // The passed argument matches somewhat closely with an accepted argument + ++currentScore; + } + } + } + // Check if the score for this method is any better relative to others + if (currentScore > bestMatchScore) { + bestMatchScore = currentScore; + bestMatch = m; + } else if (currentScore == bestMatchScore) { + bestMatch = null; // no longer a best match if more than one. + } + } + + if (bestMatch != null) { + // Found a best match during this processing, use it. + udfMethods.clear(); + udfMethods.add(bestMatch); + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala index 409be67f7af4c..26fac1b155617 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.hive +import java.lang.reflect.{InvocationTargetException, Method} + import scala.jdk.CollectionConverters._ -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF} +import org.apache.hadoop.hive.ql.exec.{DefaultUDFMethodResolver, SparkDefaultUDFMethodResolver, UDF, UDFArgumentException} +import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory, ObjectInspectorUtils} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.sql.catalyst.expressions.Expression @@ -70,8 +73,10 @@ class HiveSimpleUDFEvaluator( extends HiveUDFEvaluatorBase[UDF](funcWrapper, children) { @transient - lazy val method = function.getResolver. 
- getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) + lazy val method: Method = (function.getResolver match { + case r: DefaultUDFMethodResolver => new SparkDefaultUDFMethodResolver(r) + case r => r + }).getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) @transient private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray @@ -98,10 +103,22 @@ class HiveSimpleUDFEvaluator( method.getGenericReturnType, ObjectInspectorOptions.JAVA)) override def doEvaluate(): Any = { - val ret = FunctionRegistry.invoke( - method, - function, - conversionHelper.convertIfNecessary(inputs: _*): _*) + val arguments = conversionHelper.convertIfNecessary(inputs: _*) + // Follow behavior of o.a.h.hive.ql.exec.FunctionRegistry#invoke + val ret = try { + method.invoke(function, arguments: _*) + } catch { + case e: Exception => + val argumentString = + if (arguments == null) "null" else arguments.mkString("{", ",", "}") + val detailedMsg = if (e.isInstanceOf[InvocationTargetException]) { + e.getCause.getMessage + } else { + e.getMessage + } + throw new HiveException( + s"Unable to execute method $method with arguments $argumentString:$detailedMsg", e) + } unwrapper(ret) } } @@ -111,17 +128,41 @@ class HiveGenericUDFEvaluator( extends HiveUDFEvaluatorBase[GenericUDF](funcWrapper, children) { @transient - private lazy val argumentInspectors = children.map(toInspector) + private lazy val argumentInspectors = children.map(toInspector).toArray @transient lazy val returnInspector = { - function.initializeAndFoldConstants(argumentInspectors.toArray) + // Inline o.a.h.hive.ql.udf.generic.GenericUDF#initializeAndFoldConstants, but + // elminate calls o.a.h.hive.ql.exec.FunctionRegistry to avoid initializing Hive + // built-in UDFs. + val oi = function.initialize(argumentInspectors) + // If the UDF depends on any external resources, we can't fold because the + // resources may not be available at compile time. 
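+    // Folding requires a deterministic UDF whose arguments are all constant
+    // inspectors and whose return inspector supports constants; e.g. an
+    // upper-case UDF over a literal folds, while a current-timestamp UDF
+    // does not because it is non-deterministic.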
+ if (function.getRequiredFiles == null && function.getRequiredJars == null && + argumentInspectors.forall(ObjectInspectorUtils.isConstantObjectInspector) && + !ObjectInspectorUtils.isConstantObjectInspector(oi) && + isUDFDeterministic && + ObjectInspectorUtils.supportsConstantObjectInspector(oi)) { + val argumentValues: Array[DeferredObject] = argumentInspectors.map { argumentInspector => + new GenericUDF.DeferredJavaObject( + argumentInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue) + } + try { + val constantValue = function.evaluate(argumentValues) + ObjectInspectorUtils.getConstantObjectInspector(oi, constantValue) + } catch { + case e: HiveException => + throw new UDFArgumentException(e) + } + } else { + oi + } } @transient private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map { case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType) - }.toArray[DeferredObject] + } @transient private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector) From 30950d6c70e72241fac63f2477ea82bb86e37bd4 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Tue, 11 Mar 2025 21:01:09 +0800 Subject: [PATCH 2/6] Fix STS --- sql/hive-thriftserver/pom.xml | 20 +- .../service/cli/session/HiveSessionImpl.java | 10 +- .../src/test/resources/log4j2.properties | 3 + .../ql/exec/HiveFunctionRegistryUtils.java | 342 ++++++++++++++++++ .../SparkDefaultUDAFEvaluatorResolver.java | 105 ++++++ .../exec/SparkDefaultUDFMethodResolver.java | 298 +-------------- .../udf/generic/SparkGenericUDAFBridge.java | 194 ++++++++++ .../spark/sql/hive/hiveUDFEvaluators.scala | 24 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- 9 files changed, 681 insertions(+), 317 deletions(-) create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index e57fa5a235420..4d80247666e2a 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -148,16 +148,16 @@ byte-buddy-agent test - - ${hive.group} - hive-llap-common - ${hive.llap.scope} - - - ${hive.group} - hive-llap-client - ${hive.llap.scope} - + + + + + + + + + + net.sf.jpam jpam diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 4b55453ec7a8b..9d9b6f1c7b0e1 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -673,15 +673,23 @@ public void close() throws HiveSQLException { hiveHist.closeStream(); } try { + // Forcibly initialize thread local Hive so that + // SessionState#unCacheDataNucleusClassLoaders won't trigger + // Hive built-in UDFs initialization. 
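+          // getWithoutRegisterFns returns the thread-local Hive client without
+          // registering built-in functions, so closing the session does not pay
+          // the FunctionRegistry initialization cost.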
+ Hive.getWithoutRegisterFns(sessionState.getConf()); sessionState.close(); } finally { sessionState = null; } - } catch (IOException ioe) { + } catch (IOException | HiveException ioe) { throw new HiveSQLException("Failure to close", ioe); } finally { if (sessionState != null) { try { + // Forcibly initialize thread local Hive so that + // SessionState#unCacheDataNucleusClassLoaders won't trigger + // Hive built-in UDFs initialization. + Hive.getWithoutRegisterFns(sessionState.getConf()); sessionState.close(); } catch (Throwable t) { LOG.warn("Error closing session", t); diff --git a/sql/hive-thriftserver/src/test/resources/log4j2.properties b/sql/hive-thriftserver/src/test/resources/log4j2.properties index e6753047c9055..207fd3c22ab93 100644 --- a/sql/hive-thriftserver/src/test/resources/log4j2.properties +++ b/sql/hive-thriftserver/src/test/resources/log4j2.properties @@ -92,3 +92,6 @@ logger.parquet2.level = error logger.thriftserver.name = org.apache.spark.sql.hive.thriftserver.SparkExecuteStatementOperation logger.thriftserver.level = off + +logger.dagscheduler.name = org.apache.spark.scheduler.DAGScheduler +logger.dagscheduler.level = error diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java new file mode 100644 index 0000000000000..9e3bf24977ecd --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.*; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.spark.internal.SparkLogger; +import org.apache.spark.internal.SparkLoggerFactory; + +/** + * Copy some methods from {@link org.apache.hadoop.hive.ql.exec.FunctionRegistry} + * to avoid initializing Hive built-in UDFs. + *
+ * The code is based on Hive 2.3.10. + */ +public class HiveFunctionRegistryUtils { + + public static final SparkLogger LOG = + SparkLoggerFactory.getLogger(HiveFunctionRegistryUtils.class); + + /** + * This method is shared between UDFRegistry and UDAFRegistry. methodName will + * be "evaluate" for UDFRegistry, and "aggregate"/"evaluate"/"evaluatePartial" + * for UDAFRegistry. + * @throws UDFArgumentException + */ + public static Method getMethodInternal(Class udfClass, + String methodName, boolean exact, List argumentClasses) + throws UDFArgumentException { + + List mlist = new ArrayList<>(); + + for (Method m : udfClass.getMethods()) { + if (m.getName().equals(methodName)) { + mlist.add(m); + } + } + + return getMethodInternal(udfClass, mlist, exact, argumentClasses); + } + + /** + * Gets the closest matching method corresponding to the argument list from a + * list of methods. + * + * @param mlist + * The list of methods to inspect. + * @param exact + * Boolean to indicate whether this is an exact match or not. + * @param argumentsPassed + * The classes for the argument. + * @return The matching method. + */ + public static Method getMethodInternal(Class udfClass, List mlist, boolean exact, + List argumentsPassed) throws UDFArgumentException { + + // result + List udfMethods = new ArrayList<>(); + // The cost of the result + int leastConversionCost = Integer.MAX_VALUE; + + for (Method m : mlist) { + List argumentsAccepted = TypeInfoUtils.getParameterTypeInfos(m, + argumentsPassed.size()); + if (argumentsAccepted == null) { + // null means the method does not accept number of arguments passed. + continue; + } + + boolean match = (argumentsAccepted.size() == argumentsPassed.size()); + int conversionCost = 0; + + for (int i = 0; i < argumentsPassed.size() && match; i++) { + int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), exact); + if (cost == -1) { + match = false; + } else { + conversionCost += cost; + } + } + if (LOG.isDebugEnabled()) { + LOG.debug("Method {} match: passed = {} accepted = {} method = {}", + match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m); + } + if (match) { + // Always choose the function with least implicit conversions. + if (conversionCost < leastConversionCost) { + udfMethods.clear(); + udfMethods.add(m); + leastConversionCost = conversionCost; + // Found an exact match + if (leastConversionCost == 0) { + break; + } + } else if (conversionCost == leastConversionCost) { + // Ambiguous call: two methods with the same number of implicit + // conversions + udfMethods.add(m); + // Don't break! We might find a better match later. + } else { + // do nothing if implicitConversions > leastImplicitConversions + } + } + } + + if (udfMethods.size() == 0) { + // No matching methods found + throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist); + } + + if (udfMethods.size() > 1) { + // First try selecting methods based on the type affinity of the arguments passed + // to the candidate method arguments. + filterMethodsByTypeAffinity(udfMethods, argumentsPassed); + } + + if (udfMethods.size() > 1) { + + // if the only difference is numeric types, pick the method + // with the smallest overall numeric type. 
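+      // The "overall numeric type" of a candidate is the largest numeric rank
+      // among its parameters; all candidates must agree on their non-numeric
+      // parameters, otherwise the call is ambiguous.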
+ int lowestNumericType = Integer.MAX_VALUE; + boolean multiple = true; + Method candidate = null; + List referenceArguments = null; + + for (Method m: udfMethods) { + int maxNumericType = 0; + + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + + if (referenceArguments == null) { + // keep the arguments for reference - we want all the non-numeric + // arguments to be the same + referenceArguments = argumentsAccepted; + } + + Iterator referenceIterator = referenceArguments.iterator(); + + for (TypeInfo accepted: argumentsAccepted) { + TypeInfo reference = referenceIterator.next(); + + boolean acceptedIsPrimitive = false; + PrimitiveCategory acceptedPrimCat = PrimitiveCategory.UNKNOWN; + if (accepted.getCategory() == Category.PRIMITIVE) { + acceptedIsPrimitive = true; + acceptedPrimCat = ((PrimitiveTypeInfo) accepted).getPrimitiveCategory(); + } + if (acceptedIsPrimitive && TypeInfoUtils.numericTypes.containsKey(acceptedPrimCat)) { + // We're looking for the udf with the smallest maximum numeric type. + int typeValue = TypeInfoUtils.numericTypes.get(acceptedPrimCat); + maxNumericType = typeValue > maxNumericType ? typeValue : maxNumericType; + } else if (!accepted.equals(reference)) { + // There are non-numeric arguments that don't match from one UDF to + // another. We give up at this point. + throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); + } + } + + if (lowestNumericType > maxNumericType) { + multiple = false; + lowestNumericType = maxNumericType; + candidate = m; + } else if (maxNumericType == lowestNumericType) { + // multiple udfs with the same max type. Unless we find a lower one + // we'll give up. + multiple = true; + } + } + + if (!multiple) { + return candidate; + } else { + throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); + } + } + return udfMethods.get(0); + } + + public static Object invoke(Method m, Object thisObject, Object... arguments) + throws HiveException { + Object o; + try { + o = m.invoke(thisObject, arguments); + } catch (Exception e) { + StringBuilder argumentString = new StringBuilder(); + if (arguments == null) { + argumentString.append("null"); + } else { + argumentString.append("{"); + for (int i = 0; i < arguments.length; i++) { + if (i > 0) { + argumentString.append(","); + } + + argumentString.append(arguments[i]); + } + argumentString.append("}"); + } + + String detailedMsg = e instanceof java.lang.reflect.InvocationTargetException ? + e.getCause().getMessage() : e.getMessage(); + + throw new HiveException("Unable to execute method " + m + " with arguments " + + argumentString + ":" + detailedMsg, e); + } + return o; + } + + /** + * Returns -1 if passed does not match accepted. Otherwise return the cost + * (usually 0 for no conversion and 1 for conversion). 
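+   * <p>For lists the cost is that of the element types; for maps it is the
+   * larger of the key and value costs, or -1 if either side fails to match.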
+ */ + public static int matchCost(TypeInfo argumentPassed, + TypeInfo argumentAccepted, boolean exact) { + if (argumentAccepted.equals(argumentPassed) + || TypeInfoUtils.doPrimitiveCategoriesMatch(argumentPassed, argumentAccepted)) { + // matches + return 0; + } + if (argumentPassed.equals(TypeInfoFactory.voidTypeInfo)) { + // passing null matches everything + return 0; + } + if (argumentPassed.getCategory().equals(Category.LIST) + && argumentAccepted.getCategory().equals(Category.LIST)) { + // lists are compatible if and only-if the elements are compatible + TypeInfo argumentPassedElement = ((ListTypeInfo) argumentPassed) + .getListElementTypeInfo(); + TypeInfo argumentAcceptedElement = ((ListTypeInfo) argumentAccepted) + .getListElementTypeInfo(); + return matchCost(argumentPassedElement, argumentAcceptedElement, exact); + } + if (argumentPassed.getCategory().equals(Category.MAP) + && argumentAccepted.getCategory().equals(Category.MAP)) { + // lists are compatible if and only-if the elements are compatible + TypeInfo argumentPassedKey = ((MapTypeInfo) argumentPassed) + .getMapKeyTypeInfo(); + TypeInfo argumentAcceptedKey = ((MapTypeInfo) argumentAccepted) + .getMapKeyTypeInfo(); + TypeInfo argumentPassedValue = ((MapTypeInfo) argumentPassed) + .getMapValueTypeInfo(); + TypeInfo argumentAcceptedValue = ((MapTypeInfo) argumentAccepted) + .getMapValueTypeInfo(); + int cost1 = matchCost(argumentPassedKey, argumentAcceptedKey, exact); + int cost2 = matchCost(argumentPassedValue, argumentAcceptedValue, exact); + if (cost1 < 0 || cost2 < 0) { + return -1; + } + return Math.max(cost1, cost2); + } + + if (argumentAccepted.equals(TypeInfoFactory.unknownTypeInfo)) { + // accepting Object means accepting everything, + // but there is a conversion cost. + return 1; + } + if (!exact && TypeInfoUtils.implicitConvertible(argumentPassed, argumentAccepted)) { + return 1; + } + + return -1; + } + + /** + * Given a set of candidate methods and list of argument types, try to + * select the best candidate based on how close the passed argument types are + * to the candidate argument types. + * For a varchar argument, we would prefer evaluate(string) over evaluate(double). + * @param udfMethods list of candidate methods + * @param argumentsPassed list of argument types to match to the candidate methods + */ + static void filterMethodsByTypeAffinity(List udfMethods, List argumentsPassed) { + if (udfMethods.size() > 1) { + // Prefer methods with a closer signature based on the primitive grouping of each argument. + // Score each method based on its similarity to the passed argument types. 
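+      // e.g. a varchar argument scores a point against a string parameter (both
+      // STRING_GROUP) but not against a double parameter, so evaluate(string)
+      // outlives evaluate(double) in the candidate list.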
+ int currentScore = 0; + int bestMatchScore = 0; + Method bestMatch = null; + for (Method m: udfMethods) { + currentScore = 0; + List argumentsAccepted = + TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); + Iterator argsPassedIter = argumentsPassed.iterator(); + for (TypeInfo acceptedType : argumentsAccepted) { + // Check the affinity of the argument passed in with the accepted argument, + // based on the PrimitiveGrouping + TypeInfo passedType = argsPassedIter.next(); + if (acceptedType.getCategory() == Category.PRIMITIVE + && passedType.getCategory() == Category.PRIMITIVE) { + PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory()); + PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( + ((PrimitiveTypeInfo) passedType).getPrimitiveCategory()); + if (acceptedPg == passedPg) { + // The passed argument matches somewhat closely with an accepted argument + ++currentScore; + } + } + } + // Check if the score for this method is any better relative to others + if (currentScore > bestMatchScore) { + bestMatchScore = currentScore; + bestMatch = m; + } else if (currentScore == bestMatchScore) { + bestMatch = null; // no longer a best match if more than one. + } + } + + if (bestMatch != null) { + // Found a best match during this processing, use it. + udfMethods.clear(); + udfMethods.add(bestMatch); + } + } + } +} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java new file mode 100644 index 0000000000000..8cbc5a2cc96be --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDAFEvaluatorResolver.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec; + +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +/** + * A equivalent implementation of {@link DefaultUDAFEvaluatorResolver}, but eliminate calls + * of {@link FunctionRegistry} to avoid initializing Hive built-in UDFs. + *
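+ * The wrapped resolver's private {@code udafClass} field is extracted via
+ * reflection, so constructing this resolver never touches FunctionRegistry.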
+ * The code is based on Hive 2.3.10. + */ +@SuppressWarnings("deprecation") +public class SparkDefaultUDAFEvaluatorResolver implements UDAFEvaluatorResolver { + + /** + * The class of the UDAF. + */ + private final Class udafClass; + + /** + * Constructor. This constructor extract udafClass from {@link DefaultUDAFEvaluatorResolver} + */ + @SuppressWarnings("unchecked") + public SparkDefaultUDAFEvaluatorResolver(DefaultUDAFEvaluatorResolver wrapped) { + try { + Field udfClassField = wrapped.getClass().getDeclaredField("udafClass"); + udfClassField.setAccessible(true); + this.udafClass = (Class) udfClassField.get(wrapped); + } catch (ReflectiveOperationException rethrow) { + throw new RuntimeException(rethrow); + } + } + + /** + * Gets the evaluator class for the UDAF given the parameter types. + * + * @param argClasses + * The list of the parameter types. + */ + @SuppressWarnings("unchecked") + public Class getEvaluatorClass( + List argClasses) throws UDFArgumentException { + + ArrayList> classList = new ArrayList<>(); + + // Add all the public member classes that implement an evaluator + for (Class enclClass : udafClass.getClasses()) { + if (UDAFEvaluator.class.isAssignableFrom(enclClass)) { + classList.add((Class) enclClass); + } + } + + // Next we locate all the iterate methods for each of these classes. + ArrayList mList = new ArrayList<>(); + ArrayList> cList = new ArrayList<>(); + for (Class evaluator : classList) { + for (Method m : evaluator.getMethods()) { + if (m.getName().equalsIgnoreCase("iterate")) { + mList.add(m); + cList.add(evaluator); + } + } + } + + Method m = HiveFunctionRegistryUtils.getMethodInternal(udafClass, mList, false, argClasses); + + // Find the class that has this method. + // Note that Method.getDeclaringClass() may not work here because the method + // can be inherited from a base class. + int found = -1; + for (int i = 0; i < mList.size(); i++) { + if (mList.get(i) == m) { + if (found == -1) { + found = i; + } else { + throw new AmbiguousMethodException(udafClass, argClasses, mList); + } + } + } + assert (found != -1); + + return cList.get(found); + } +} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java index 34a6c683833ce..1c2cbb108e1b1 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/SparkDefaultUDFMethodResolver.java @@ -17,31 +17,21 @@ package org.apache.hadoop.hive.ql.exec; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.*; import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.util.ArrayList; -import java.util.Iterator; import java.util.List; -import org.apache.spark.internal.SparkLogger; -import org.apache.spark.internal.SparkLoggerFactory; - /** * A equivalent implementation of {@link DefaultUDFMethodResolver}, but eliminate calls * of {@link org.apache.hadoop.hive.ql.exec.FunctionRegistry} to avoid initializing Hive * built-in UDFs. + *
+ * The code is based on Hive 2.3.10. */ public class SparkDefaultUDFMethodResolver implements UDFMethodResolver { - public static final SparkLogger LOG = - SparkLoggerFactory.getLogger(SparkDefaultUDFMethodResolver.class); - /** * The class of the UDF. */ @@ -51,11 +41,14 @@ public class SparkDefaultUDFMethodResolver implements UDFMethodResolver { * Constructor. This constructor extract udfClass from {@link DefaultUDFMethodResolver} */ @SuppressWarnings("unchecked") - public SparkDefaultUDFMethodResolver( - DefaultUDFMethodResolver wrapped) throws ReflectiveOperationException { - Field udfClassField = wrapped.getClass().getDeclaredField("udfClass"); - udfClassField.setAccessible(true); - this.udfClass = (Class) udfClassField.get(wrapped); + public SparkDefaultUDFMethodResolver(DefaultUDFMethodResolver wrapped) { + try { + Field udfClassField = wrapped.getClass().getDeclaredField("udfClass"); + udfClassField.setAccessible(true); + this.udfClass = (Class) udfClassField.get(wrapped); + } catch (ReflectiveOperationException rethrow) { + throw new RuntimeException(rethrow); + } } /** @@ -67,275 +60,6 @@ public SparkDefaultUDFMethodResolver( */ @Override public Method getEvalMethod(List argClasses) throws UDFArgumentException { - return getMethodInternal(udfClass, "evaluate", false, argClasses); - } - - // Below methods are copied from Hive 2.3.10 o.a.h.hive.ql.exec.FunctionRegistry - - /** - * This method is shared between UDFRegistry and UDAFRegistry. methodName will - * be "evaluate" for UDFRegistry, and "aggregate"/"evaluate"/"evaluatePartial" - * for UDAFRegistry. - * @throws UDFArgumentException - */ - public static Method getMethodInternal(Class udfClass, - String methodName, boolean exact, List argumentClasses) - throws UDFArgumentException { - - List mlist = new ArrayList<>(); - - for (Method m : udfClass.getMethods()) { - if (m.getName().equals(methodName)) { - mlist.add(m); - } - } - - return getMethodInternal(udfClass, mlist, exact, argumentClasses); - } - - /** - * Gets the closest matching method corresponding to the argument list from a - * list of methods. - * - * @param mlist - * The list of methods to inspect. - * @param exact - * Boolean to indicate whether this is an exact match or not. - * @param argumentsPassed - * The classes for the argument. - * @return The matching method. - */ - public static Method getMethodInternal(Class udfClass, List mlist, boolean exact, - List argumentsPassed) throws UDFArgumentException { - - // result - List udfMethods = new ArrayList<>(); - // The cost of the result - int leastConversionCost = Integer.MAX_VALUE; - - for (Method m : mlist) { - List argumentsAccepted = TypeInfoUtils.getParameterTypeInfos(m, - argumentsPassed.size()); - if (argumentsAccepted == null) { - // null means the method does not accept number of arguments passed. - continue; - } - - boolean match = (argumentsAccepted.size() == argumentsPassed.size()); - int conversionCost = 0; - - for (int i = 0; i < argumentsPassed.size() && match; i++) { - int cost = matchCost(argumentsPassed.get(i), argumentsAccepted.get(i), exact); - if (cost == -1) { - match = false; - } else { - conversionCost += cost; - } - } - if (LOG.isDebugEnabled()) { - LOG.debug("Method {} match: passed = {} accepted = {} method = {}", - match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m); - } - if (match) { - // Always choose the function with least implicit conversions. 
- if (conversionCost < leastConversionCost) { - udfMethods.clear(); - udfMethods.add(m); - leastConversionCost = conversionCost; - // Found an exact match - if (leastConversionCost == 0) { - break; - } - } else if (conversionCost == leastConversionCost) { - // Ambiguous call: two methods with the same number of implicit - // conversions - udfMethods.add(m); - // Don't break! We might find a better match later. - } else { - // do nothing if implicitConversions > leastImplicitConversions - } - } - } - - if (udfMethods.size() == 0) { - // No matching methods found - throw new NoMatchingMethodException(udfClass, argumentsPassed, mlist); - } - - if (udfMethods.size() > 1) { - // First try selecting methods based on the type affinity of the arguments passed - // to the candidate method arguments. - filterMethodsByTypeAffinity(udfMethods, argumentsPassed); - } - - if (udfMethods.size() > 1) { - - // if the only difference is numeric types, pick the method - // with the smallest overall numeric type. - int lowestNumericType = Integer.MAX_VALUE; - boolean multiple = true; - Method candidate = null; - List referenceArguments = null; - - for (Method m: udfMethods) { - int maxNumericType = 0; - - List argumentsAccepted = - TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); - - if (referenceArguments == null) { - // keep the arguments for reference - we want all the non-numeric - // arguments to be the same - referenceArguments = argumentsAccepted; - } - - Iterator referenceIterator = referenceArguments.iterator(); - - for (TypeInfo accepted: argumentsAccepted) { - TypeInfo reference = referenceIterator.next(); - - boolean acceptedIsPrimitive = false; - PrimitiveCategory acceptedPrimCat = PrimitiveCategory.UNKNOWN; - if (accepted.getCategory() == Category.PRIMITIVE) { - acceptedIsPrimitive = true; - acceptedPrimCat = ((PrimitiveTypeInfo) accepted).getPrimitiveCategory(); - } - if (acceptedIsPrimitive && TypeInfoUtils.numericTypes.containsKey(acceptedPrimCat)) { - // We're looking for the udf with the smallest maximum numeric type. - int typeValue = TypeInfoUtils.numericTypes.get(acceptedPrimCat); - maxNumericType = typeValue > maxNumericType ? typeValue : maxNumericType; - } else if (!accepted.equals(reference)) { - // There are non-numeric arguments that don't match from one UDF to - // another. We give up at this point. - throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); - } - } - - if (lowestNumericType > maxNumericType) { - multiple = false; - lowestNumericType = maxNumericType; - candidate = m; - } else if (maxNumericType == lowestNumericType) { - // multiple udfs with the same max type. Unless we find a lower one - // we'll give up. - multiple = true; - } - } - - if (!multiple) { - return candidate; - } else { - throw new AmbiguousMethodException(udfClass, argumentsPassed, mlist); - } - } - return udfMethods.get(0); - } - - /** - * Returns -1 if passed does not match accepted. Otherwise return the cost - * (usually 0 for no conversion and 1 for conversion). 
- */ - public static int matchCost(TypeInfo argumentPassed, - TypeInfo argumentAccepted, boolean exact) { - if (argumentAccepted.equals(argumentPassed) - || TypeInfoUtils.doPrimitiveCategoriesMatch(argumentPassed, argumentAccepted)) { - // matches - return 0; - } - if (argumentPassed.equals(TypeInfoFactory.voidTypeInfo)) { - // passing null matches everything - return 0; - } - if (argumentPassed.getCategory().equals(Category.LIST) - && argumentAccepted.getCategory().equals(Category.LIST)) { - // lists are compatible if and only-if the elements are compatible - TypeInfo argumentPassedElement = ((ListTypeInfo) argumentPassed) - .getListElementTypeInfo(); - TypeInfo argumentAcceptedElement = ((ListTypeInfo) argumentAccepted) - .getListElementTypeInfo(); - return matchCost(argumentPassedElement, argumentAcceptedElement, exact); - } - if (argumentPassed.getCategory().equals(Category.MAP) - && argumentAccepted.getCategory().equals(Category.MAP)) { - // lists are compatible if and only-if the elements are compatible - TypeInfo argumentPassedKey = ((MapTypeInfo) argumentPassed) - .getMapKeyTypeInfo(); - TypeInfo argumentAcceptedKey = ((MapTypeInfo) argumentAccepted) - .getMapKeyTypeInfo(); - TypeInfo argumentPassedValue = ((MapTypeInfo) argumentPassed) - .getMapValueTypeInfo(); - TypeInfo argumentAcceptedValue = ((MapTypeInfo) argumentAccepted) - .getMapValueTypeInfo(); - int cost1 = matchCost(argumentPassedKey, argumentAcceptedKey, exact); - int cost2 = matchCost(argumentPassedValue, argumentAcceptedValue, exact); - if (cost1 < 0 || cost2 < 0) { - return -1; - } - return Math.max(cost1, cost2); - } - - if (argumentAccepted.equals(TypeInfoFactory.unknownTypeInfo)) { - // accepting Object means accepting everything, - // but there is a conversion cost. - return 1; - } - if (!exact && TypeInfoUtils.implicitConvertible(argumentPassed, argumentAccepted)) { - return 1; - } - - return -1; - } - - /** - * Given a set of candidate methods and list of argument types, try to - * select the best candidate based on how close the passed argument types are - * to the candidate argument types. - * For a varchar argument, we would prefer evaluate(string) over evaluate(double). - * @param udfMethods list of candidate methods - * @param argumentsPassed list of argument types to match to the candidate methods - */ - static void filterMethodsByTypeAffinity(List udfMethods, List argumentsPassed) { - if (udfMethods.size() > 1) { - // Prefer methods with a closer signature based on the primitive grouping of each argument. - // Score each method based on its similarity to the passed argument types. 
- int currentScore = 0; - int bestMatchScore = 0; - Method bestMatch = null; - for (Method m: udfMethods) { - currentScore = 0; - List argumentsAccepted = - TypeInfoUtils.getParameterTypeInfos(m, argumentsPassed.size()); - Iterator argsPassedIter = argumentsPassed.iterator(); - for (TypeInfo acceptedType : argumentsAccepted) { - // Check the affinity of the argument passed in with the accepted argument, - // based on the PrimitiveGrouping - TypeInfo passedType = argsPassedIter.next(); - if (acceptedType.getCategory() == Category.PRIMITIVE - && passedType.getCategory() == Category.PRIMITIVE) { - PrimitiveGrouping acceptedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( - ((PrimitiveTypeInfo) acceptedType).getPrimitiveCategory()); - PrimitiveGrouping passedPg = PrimitiveObjectInspectorUtils.getPrimitiveGrouping( - ((PrimitiveTypeInfo) passedType).getPrimitiveCategory()); - if (acceptedPg == passedPg) { - // The passed argument matches somewhat closely with an accepted argument - ++currentScore; - } - } - } - // Check if the score for this method is any better relative to others - if (currentScore > bestMatchScore) { - bestMatchScore = currentScore; - bestMatch = m; - } else if (currentScore == bestMatchScore) { - bestMatch = null; // no longer a best match if more than one. - } - } - - if (bestMatch != null) { - // Found a best match during this processing, use it. - udfMethods.clear(); - udfMethods.add(bestMatch); - } - } + return HiveFunctionRegistryUtils.getMethodInternal(udfClass, "evaluate", false, argClasses); } } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java new file mode 100644 index 0000000000000..86a6288e1007a --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.udf.generic; + +import java.io.Serializable; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Arrays; + +import org.apache.hadoop.hive.ql.exec.*; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.util.ReflectionUtils; + +/** + * A equivalent implementation of {@link GenericUDAFBridge}, but eliminate calls + * of {@link FunctionRegistry} to avoid initializing Hive built-in UDFs. + *
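+ * When the bridged UDAF relies on {@link DefaultUDAFEvaluatorResolver}, the
+ * evaluator class is looked up through {@link SparkDefaultUDAFEvaluatorResolver}
+ * instead.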
+ * The code is based on Hive 2.3.10. + */ +@SuppressWarnings("deprecation") +public class SparkGenericUDAFBridge extends AbstractGenericUDAFResolver { + + UDAF udaf; + + public SparkGenericUDAFBridge(UDAF udaf) { + this.udaf = udaf; + } + + public Class getUDAFClass() { + return udaf.getClass(); + } + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { + + UDAFEvaluatorResolver resolver = udaf.getResolver(); + if (resolver instanceof DefaultUDAFEvaluatorResolver) { + resolver = new SparkDefaultUDAFEvaluatorResolver((DefaultUDAFEvaluatorResolver) resolver); + } + Class udafEvaluatorClass = + resolver.getEvaluatorClass(Arrays.asList(parameters)); + + return new GenericUDAFBridgeEvaluator(udafEvaluatorClass); + } + + /** + * GenericUDAFBridgeEvaluator. + */ + public static class GenericUDAFBridgeEvaluator extends GenericUDAFEvaluator + implements Serializable { + + private static final long serialVersionUID = 1L; + + // Used by serialization only + public GenericUDAFBridgeEvaluator() { + } + + public Class getUdafEvaluator() { + return udafEvaluator; + } + + public void setUdafEvaluator(Class udafEvaluator) { + this.udafEvaluator = udafEvaluator; + } + + public GenericUDAFBridgeEvaluator( + Class udafEvaluator) { + this.udafEvaluator = udafEvaluator; + } + + Class udafEvaluator; + + transient ObjectInspector[] parameterOIs; + transient Object result; + + transient Method iterateMethod; + transient Method mergeMethod; + transient Method terminatePartialMethod; + transient Method terminateMethod; + + transient ConversionHelper conversionHelper; + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { + super.init(m, parameters); + parameterOIs = parameters; + + // Get the reflection methods from ue + for (Method method : udafEvaluator.getMethods()) { + method.setAccessible(true); + if (method.getName().equals("iterate")) { + iterateMethod = method; + } + if (method.getName().equals("merge")) { + mergeMethod = method; + } + if (method.getName().equals("terminatePartial")) { + terminatePartialMethod = method; + } + if (method.getName().equals("terminate")) { + terminateMethod = method; + } + } + + // Input: do Java/Writable conversion if needed + Method aggregateMethod = null; + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { + aggregateMethod = iterateMethod; + } else { + aggregateMethod = mergeMethod; + } + conversionHelper = new ConversionHelper(aggregateMethod, parameters); + + // Output: get the evaluate method + Method evaluateMethod = null; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) { + evaluateMethod = terminatePartialMethod; + } else { + evaluateMethod = terminateMethod; + } + // Get the output ObjectInspector from the return type. + Type returnType = evaluateMethod.getGenericReturnType(); + try { + return ObjectInspectorFactory.getReflectionObjectInspector(returnType, + ObjectInspectorOptions.JAVA); + } catch (RuntimeException e) { + throw new HiveException("Cannot recognize return type " + returnType + + " from " + evaluateMethod, e); + } + } + + /** class for storing UDAFEvaluator value. 
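+     * Each aggregation buffer owns a fresh {@link UDAFEvaluator} instance.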
*/ + static class UDAFAgg extends AbstractAggregationBuffer { + UDAFEvaluator ueObject; + + UDAFAgg(UDAFEvaluator ueObject) { + this.ueObject = ueObject; + } + } + + @Override + public AggregationBuffer getNewAggregationBuffer() { + return new UDAFAgg(ReflectionUtils.newInstance(udafEvaluator, null)); + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + ((UDAFAgg) agg).ueObject.init(); + } + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + HiveFunctionRegistryUtils.invoke(iterateMethod, ((UDAFAgg) agg).ueObject, + conversionHelper.convertIfNecessary(parameters)); + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + HiveFunctionRegistryUtils.invoke(mergeMethod, ((UDAFAgg) agg).ueObject, + conversionHelper.convertIfNecessary(partial)); + } + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + return HiveFunctionRegistryUtils.invoke(terminateMethod, ((UDAFAgg) agg).ueObject); + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + return HiveFunctionRegistryUtils.invoke(terminatePartialMethod, ((UDAFAgg) agg).ueObject); + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala index 26fac1b155617..6979d89d6b230 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive -import java.lang.reflect.{InvocationTargetException, Method} +import java.lang.reflect.Method import scala.jdk.CollectionConverters._ -import org.apache.hadoop.hive.ql.exec.{DefaultUDFMethodResolver, SparkDefaultUDFMethodResolver, UDF, UDFArgumentException} +import org.apache.hadoop.hive.ql.exec.{DefaultUDFMethodResolver, HiveFunctionRegistryUtils, SparkDefaultUDFMethodResolver, UDF, UDFArgumentException} import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF @@ -103,22 +103,10 @@ class HiveSimpleUDFEvaluator( method.getGenericReturnType, ObjectInspectorOptions.JAVA)) override def doEvaluate(): Any = { - val arguments = conversionHelper.convertIfNecessary(inputs: _*) - // Follow behavior of o.a.h.hive.ql.exec.FunctionRegistry#invoke - val ret = try { - method.invoke(function, arguments: _*) - } catch { - case e: Exception => - val argumentString = - if (arguments == null) "null" else arguments.mkString("{", ",", "}") - val detailedMsg = if (e.isInstanceOf[InvocationTargetException]) { - e.getCause.getMessage - } else { - e.getMessage - } - throw new HiveException( - s"Unable to execute method $method with arguments $argumentString:$detailedMsg", e) - } + val ret = HiveFunctionRegistryUtils.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs: _*): _*) unwrapper(ret) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 227c6a618e3d4..bf708eecf0c0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -353,7 +353,7 @@ private[hive] case class HiveUDAFFunction( private def newEvaluator(): GenericUDAFEvaluator = { 
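+    // SparkGenericUDAFBridge mirrors Hive's GenericUDAFBridge but resolves the
+    // evaluator class without triggering FunctionRegistry's static initializer.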
     val resolver = if (isUDAFBridgeRequired) {
-      new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
+      new SparkGenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }

From 68c571889677bfe3e8a87325a26deeac0dcf4ec6 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Tue, 11 Mar 2025 21:02:00 +0800
Subject: [PATCH 3/6] nit

---
 sql/hive-thriftserver/pom.xml | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 4d80247666e2a..135b84cd01f85 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -148,16 +148,6 @@
       <artifactId>byte-buddy-agent</artifactId>
       <scope>test</scope>
     </dependency>
-
-
-
-
-
-
-
-
-
-
     <dependency>
       <groupId>net.sf.jpam</groupId>
      <artifactId>jpam</artifactId>

From 9506f9929c78e42f0d2d2a4323ccb77eb8d8887e Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Wed, 12 Mar 2025 14:52:59 +0800
Subject: [PATCH 4/6] simplify SparkGenericUDAFBridge

---
 project/SparkBuild.scala                      |  10 ++
 .../udf/generic/SparkGenericUDAFBridge.java   | 131 +++---------------
 2 files changed, 27 insertions(+), 114 deletions(-)

diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7ea894e5efcaa..fafcacbbdd237 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -412,6 +412,8 @@ object SparkBuild extends PomBuild {
   /* Hive console settings */
   enable(Hive.settings)(hive)
 
+  enable(HiveThriftServer.settings)(hiveThriftServer)
+
   enable(SparkConnectCommon.settings)(connectCommon)
   enable(SparkConnect.settings)(connect)
   enable(SparkConnectClient.settings)(connectClient)
@@ -1203,6 +1205,14 @@ object Hive {
   )
 }
 
+object HiveThriftServer {
+  lazy val settings = Seq(
+    excludeDependencies ++= Seq(
+      ExclusionRule("org.apache.hive", "hive-llap-common"),
+      ExclusionRule("org.apache.hive", "hive-llap-client"))
+  )
+}
+
 object YARN {
   val genConfigProperties = TaskKey[Unit]("gen-config-properties",
     "Generate config.properties which contains a setting whether Hadoop is provided or not")
diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java
index 86a6288e1007a..72a758de90a26 100644
--- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/udf/generic/SparkGenericUDAFBridge.java
@@ -17,20 +17,19 @@
 
 package org.apache.hadoop.hive.ql.udf.generic;
 
-import java.io.Serializable;
-import java.lang.reflect.Method;
-import java.lang.reflect.Type;
+import java.io.Serial;
 import java.util.Arrays;
 
-import org.apache.hadoop.hive.ql.exec.*;
+import org.apache.hadoop.hive.ql.exec.DefaultUDAFEvaluatorResolver;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.HiveFunctionRegistryUtils;
+import org.apache.hadoop.hive.ql.exec.SparkDefaultUDAFEvaluatorResolver;
+import org.apache.hadoop.hive.ql.exec.UDAF;
+import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
+import org.apache.hadoop.hive.ql.exec.UDAFEvaluatorResolver;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
-import org.apache.hadoop.util.ReflectionUtils;
 
 /**
  * An equivalent implementation of {@link GenericUDAFBridge}, but eliminates calls
@@ -39,16 +38,10 @@
  * The code is based on Hive 2.3.10.
  */
 @SuppressWarnings("deprecation")
-public class SparkGenericUDAFBridge extends AbstractGenericUDAFResolver {
-
-  UDAF udaf;
+public class SparkGenericUDAFBridge extends GenericUDAFBridge {
 
   public SparkGenericUDAFBridge(UDAF udaf) {
-    this.udaf = udaf;
-  }
-
-  public Class<? extends UDAF> getUDAFClass() {
-    return udaf.getClass();
+    super(udaf);
   }
@@ -61,112 +54,22 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticE
 
     Class<? extends UDAFEvaluator> udafEvaluatorClass =
       resolver.getEvaluatorClass(Arrays.asList(parameters));
 
-    return new GenericUDAFBridgeEvaluator(udafEvaluatorClass);
+    return new SparkGenericUDAFBridgeEvaluator(udafEvaluatorClass);
   }
 
-  /**
-   * GenericUDAFBridgeEvaluator.
-   */
-  public static class GenericUDAFBridgeEvaluator extends GenericUDAFEvaluator
-      implements Serializable {
+  public static class SparkGenericUDAFBridgeEvaluator
+      extends GenericUDAFBridge.GenericUDAFBridgeEvaluator {
 
+    @Serial
     private static final long serialVersionUID = 1L;
 
     // Used by serialization only
-    public GenericUDAFBridgeEvaluator() {
-    }
-
-    public Class<? extends UDAFEvaluator> getUdafEvaluator() {
-      return udafEvaluator;
-    }
-
-    public void setUdafEvaluator(Class<? extends UDAFEvaluator> udafEvaluator) {
-      this.udafEvaluator = udafEvaluator;
+    public SparkGenericUDAFBridgeEvaluator() {
     }
 
-    public GenericUDAFBridgeEvaluator(
+    public SparkGenericUDAFBridgeEvaluator(
         Class<? extends UDAFEvaluator> udafEvaluator) {
-      this.udafEvaluator = udafEvaluator;
-    }
-
-    Class<? extends UDAFEvaluator> udafEvaluator;
-
-    transient ObjectInspector[] parameterOIs;
-    transient Object result;
-
-    transient Method iterateMethod;
-    transient Method mergeMethod;
-    transient Method terminatePartialMethod;
-    transient Method terminateMethod;
-
-    transient ConversionHelper conversionHelper;
-
-    @Override
-    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
-      super.init(m, parameters);
-      parameterOIs = parameters;
-
-      // Get the reflection methods from ue
-      for (Method method : udafEvaluator.getMethods()) {
-        method.setAccessible(true);
-        if (method.getName().equals("iterate")) {
-          iterateMethod = method;
-        }
-        if (method.getName().equals("merge")) {
-          mergeMethod = method;
-        }
-        if (method.getName().equals("terminatePartial")) {
-          terminatePartialMethod = method;
-        }
-        if (method.getName().equals("terminate")) {
-          terminateMethod = method;
-        }
-      }
-
-      // Input: do Java/Writable conversion if needed
-      Method aggregateMethod = null;
-      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
-        aggregateMethod = iterateMethod;
-      } else {
-        aggregateMethod = mergeMethod;
-      }
-      conversionHelper = new ConversionHelper(aggregateMethod, parameters);
-
-      // Output: get the evaluate method
-      Method evaluateMethod = null;
-      if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
-        evaluateMethod = terminatePartialMethod;
-      } else {
-        evaluateMethod = terminateMethod;
-      }
-      // Get the output ObjectInspector from the return type.
-      Type returnType = evaluateMethod.getGenericReturnType();
-      try {
-        return ObjectInspectorFactory.getReflectionObjectInspector(returnType,
-          ObjectInspectorOptions.JAVA);
-      } catch (RuntimeException e) {
-        throw new HiveException("Cannot recognize return type " + returnType
-          + " from " + evaluateMethod, e);
-      }
-    }
-
-    /** class for storing UDAFEvaluator value. */
-    static class UDAFAgg extends AbstractAggregationBuffer {
-      UDAFEvaluator ueObject;
-
-      UDAFAgg(UDAFEvaluator ueObject) {
-        this.ueObject = ueObject;
-      }
-    }
-
-    @Override
-    public AggregationBuffer getNewAggregationBuffer() {
-      return new UDAFAgg(ReflectionUtils.newInstance(udafEvaluator, null));
-    }
-
-    @Override
-    public void reset(AggregationBuffer agg) throws HiveException {
-      ((UDAFAgg) agg).ueObject.init();
+      super(udafEvaluator);
     }
 
     @Override

From 55feae10ce0bd05e34ae0bce90f7c52b4c051914 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Wed, 12 Mar 2025 14:54:11 +0800
Subject: [PATCH 5/6] typo

---
 .../scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
index 6979d89d6b230..866d88dee8783 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -121,7 +121,7 @@ class HiveGenericUDFEvaluator(
   @transient
   lazy val returnInspector = {
     // Inline o.a.h.hive.ql.udf.generic.GenericUDF#initializeAndFoldConstants, but
-    // elminate calls o.a.h.hive.ql.exec.FunctionRegistry to avoid initializing Hive
+    // eliminate calls o.a.h.hive.ql.exec.FunctionRegistry to avoid initializing Hive
     // built-in UDFs.
     val oi = function.initialize(argumentInspectors)
     // If the UDF depends on any external resources, we can't fold because the

From dd80f9c993fb4c24b938b0b06f7d7a730bf2226e Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Wed, 12 Mar 2025 15:35:22 +0800
Subject: [PATCH 6/6] nit

---
 .../hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java
index 9e3bf24977ecd..333ae0151a5a6 100644
--- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java
+++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/exec/HiveFunctionRegistryUtils.java
@@ -41,7 +41,7 @@
 public class HiveFunctionRegistryUtils {
 
   public static final SparkLogger LOG =
-    SparkLoggerFactory.getLogger(HiveFunctionRegistryUtils.class);
+      SparkLoggerFactory.getLogger(HiveFunctionRegistryUtils.class);
 
   /**
    * This method is shared between UDFRegistry and UDAFRegistry. methodName will
@@ -103,10 +103,10 @@ public static Method getMethodInternal(Class udfClass, List mlist, bo
         conversionCost += cost;
       }
     }
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("Method {} match: passed = {} accepted = {} method = {}",
-        match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m);
-    }
+
+    LOG.debug("Method {} match: passed = {} accepted = {} method = {}",
+      match ? "did" : "didn't", argumentsPassed, argumentsAccepted, m);
+
     if (match) {
       // Always choose the function with least implicit conversions.
       if (conversionCost < leastConversionCost) {
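
Note on HiveFunctionRegistryUtils.invoke: every reflective call in this series
(HiveSimpleUDFEvaluator.doEvaluate and the iterate/merge/terminate overrides of
SparkGenericUDAFBridgeEvaluator) is routed through this helper, but its body never
appears in these hunks. The inline Scala logic it replaces, which "followed the
behavior of o.a.h.hive.ql.exec.FunctionRegistry#invoke", suggests its shape. The
sketch below is illustrative only, not the committed code; the wrapper class name
InvokeSketch is a placeholder:

  import java.lang.reflect.InvocationTargetException;
  import java.lang.reflect.Method;
  import java.util.Arrays;

  import org.apache.hadoop.hive.ql.metadata.HiveException;

  public final class InvokeSketch {
    /**
     * Reflectively invokes the resolved UDF/UDAF method and wraps any failure in
     * HiveException, unwrapping InvocationTargetException so the UDF's own error
     * message is surfaced instead of the reflection wrapper's.
     */
    public static Object invoke(Method m, Object thisObject, Object... arguments)
        throws HiveException {
      try {
        return m.invoke(thisObject, arguments);
      } catch (Exception e) {
        // Mirror the diagnostic format of the replaced Scala code above.
        String argumentString =
            (arguments == null) ? "null" : Arrays.toString(arguments);
        String detailedMsg = (e instanceof InvocationTargetException)
            ? e.getCause().getMessage()
            : e.getMessage();
        throw new HiveException("Unable to execute method " + m
            + " with arguments " + argumentString + ":" + detailedMsg, e);
      }
    }
  }

Keeping this single exception-wrapping path in a Spark-owned class is what lets the
evaluators avoid touching FunctionRegistry, whose static initializer would otherwise
register every Hive built-in UDF on first use.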