
Commit baef6ec

Merge remote-tracking branch 'origin/master' into execution
2 parents: 6c7d259 + 4bb6a53

7 files changed (+93, -21 lines)

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 4 additions & 4 deletions

@@ -1277,10 +1277,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
    */
   test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3))
     val firstShuffleId = firstShuffleDep.shuffleId
     val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
     submit(reduceRdd, Array(0))

@@ -1583,7 +1583,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
    */
   test("run trivial shuffle with out-of-band executor failure and retry") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))

@@ -1791,7 +1791,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
   test("reduce tasks should be placed locally with map output") {
     // Create a shuffleMapRdd with 1 partition
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
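
These edits align each test's HashPartitioner with the partition count of the downstream RDD. As a quick illustration (a standalone sketch, not code from this commit), the number passed to HashPartitioner is exactly the number of reduce partitions the shuffle will produce:

import org.apache.spark.HashPartitioner

// HashPartitioner(n) routes every key into one of n reduce partitions, so the
// count given here is the number of partitions the shuffle output will have.
val partitioner = new HashPartitioner(3)
assert(partitioner.numPartitions == 3)
Seq("a", "b", "c").foreach { key =>
  val p = partitioner.getPartition(key)
  assert(p >= 0 && p < partitioner.numPartitions)
}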

python/pyspark/ml/feature.py

Lines changed: 28 additions & 5 deletions

@@ -3043,26 +3043,35 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
                            "Force to index label whether it is numeric or string",
                            typeConverter=TypeConverters.toBoolean)
 
+    stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType",
+                                   "How to order categories of a string feature column used by " +
+                                   "StringIndexer. The last category after ordering is dropped " +
+                                   "when encoding strings. Supported options: frequencyDesc, " +
+                                   "frequencyAsc, alphabetDesc, alphabetAsc. The default value " +
+                                   "is frequencyDesc. When the ordering is set to alphabetDesc, " +
+                                   "RFormula drops the same category as R when encoding strings.",
+                                   typeConverter=TypeConverters.toString)
+
     @keyword_only
     def __init__(self, formula=None, featuresCol="features", labelCol="label",
-                 forceIndexLabel=False):
+                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
         """
         __init__(self, formula=None, featuresCol="features", labelCol="label", \
-                 forceIndexLabel=False)
+                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
         """
         super(RFormula, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
-        self._setDefault(forceIndexLabel=False)
+        self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.5.0")
     def setParams(self, formula=None, featuresCol="features", labelCol="label",
-                  forceIndexLabel=False):
+                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
         """
         setParams(self, formula=None, featuresCol="features", labelCol="label", \
-                  forceIndexLabel=False)
+                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
         Sets params for RFormula.
         """
         kwargs = self._input_kwargs

@@ -3096,6 +3105,20 @@ def getForceIndexLabel(self):
         """
         return self.getOrDefault(self.forceIndexLabel)
 
+    @since("2.3.0")
+    def setStringIndexerOrderType(self, value):
+        """
+        Sets the value of :py:attr:`stringIndexerOrderType`.
+        """
+        return self._set(stringIndexerOrderType=value)
+
+    @since("2.3.0")
+    def getStringIndexerOrderType(self):
+        """
+        Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'.
+        """
+        return self.getOrDefault(self.stringIndexerOrderType)
+
     def _create_model(self, java_model):
         return RFormulaModel(java_model)
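
For context, a minimal usage sketch of the new ordering option through the Scala RFormula that this Python wrapper binds to (the data and column names are made up, and it assumes an active SparkSession named `spark`):

import org.apache.spark.ml.feature.RFormula

val df = spark.createDataFrame(Seq(
  (1.0, 1.0, "a"), (0.0, 2.0, "b"), (1.0, 0.0, "a")
)).toDF("y", "x", "s")

// Order the categories of string column "s" alphabetically (descending) before
// encoding; the last category after ordering is dropped, matching R's choice
// of reference level.
val rf = new RFormula()
  .setFormula("y ~ x + s")
  .setStringIndexerOrderType("alphabetDesc")

rf.fit(df).transform(df).select("features", "label").show()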

python/pyspark/ml/tests.py

Lines changed: 13 additions & 0 deletions
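
The new test below relies on the alphabetDesc ordering: the categories of column "s" sort as ["b", "a"], the last one ("a") is dropped as the reference level, so rows with s="b" encode the indicator as 1.0 and rows with s="a" as 0.0, which is where the expected feature vectors [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] come from.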

@@ -538,6 +538,19 @@ def test_rformula_force_index_label(self):
         transformedDF2 = model2.transform(df)
         self.assertEqual(transformedDF2.head().label, 0.0)
 
+    def test_rformula_string_indexer_order_type(self):
+        df = self.spark.createDataFrame([
+            (1.0, 1.0, "a"),
+            (0.0, 2.0, "b"),
+            (1.0, 0.0, "a")], ["y", "x", "s"])
+        rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
+        self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
+        transformedDF = rf.fit(df).transform(df)
+        observed = transformedDF.select("features").collect()
+        expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
+        for i in range(0, len(expected)):
+            self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
+
 
 class HasInducedError(Params):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 9 additions & 8 deletions

@@ -1105,8 +1105,9 @@ class SessionCatalog(
     !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT))
   }
 
-  protected def failFunctionLookup(name: String): Nothing = {
-    throw new NoSuchFunctionException(db = currentDb, func = name)
+  protected def failFunctionLookup(name: FunctionIdentifier): Nothing = {
+    throw new NoSuchFunctionException(
+      db = name.database.getOrElse(getCurrentDatabase), func = name.funcName)
   }
 
   /**

@@ -1128,7 +1129,7 @@ class SessionCatalog(
         qualifiedName.database.orNull,
         qualifiedName.identifier)
     } else {
-      failFunctionLookup(name.funcName)
+      failFunctionLookup(name)
     }
   }
 }

@@ -1158,8 +1159,8 @@ class SessionCatalog(
     }
 
     // If the name itself is not qualified, add the current database to it.
-    val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
-    val qualifiedName = name.copy(database = database)
+    val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
+    val qualifiedName = name.copy(database = Some(database))
 
     if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
       // This function has been already loaded into the function registry.

@@ -1172,10 +1173,10 @@ class SessionCatalog(
     // in the metastore). We need to first put the function in the FunctionRegistry.
     // TODO: why not just check whether the function exists first?
     val catalogFunction = try {
-      externalCatalog.getFunction(currentDb, name.funcName)
+      externalCatalog.getFunction(database, name.funcName)
     } catch {
-      case e: AnalysisException => failFunctionLookup(name.funcName)
-      case e: NoSuchPermanentFunctionException => failFunctionLookup(name.funcName)
+      case _: AnalysisException => failFunctionLookup(name)
+      case _: NoSuchPermanentFunctionException => failFunctionLookup(name)
     }
     loadFunctionResources(catalogFunction.resources)
     // Please note that qualifiedName is provided by the user. However,
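
A standalone sketch (not the catalog code itself) of the resolution rule the patched lookup follows: the identifier's own database, when present, is used for both the metastore lookup and the error message, and only unqualified names fall back to the current database. FunctionIdentifier is the existing Catalyst identifier class; the database names here are made up.

import org.apache.spark.sql.catalyst.FunctionIdentifier

// The database carried by the identifier wins; unqualified names fall back to
// the session's current database.
def resolveDatabase(name: FunctionIdentifier, currentDb: String): String =
  name.database.getOrElse(currentDb)

assert(resolveDatabase(FunctionIdentifier("test_avg", Some("db1")), "default") == "db1")
assert(resolveDatabase(FunctionIdentifier("test_avg", None), "default") == "default")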

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 16 additions & 1 deletion

@@ -2780,7 +2780,22 @@ class Dataset[T] private[sql](
     createTempViewCommand(viewName, replace = false, global = true)
   }
 
-  private[sql] def createTempViewCommand(
+  /**
+   * Creates or replaces a global temporary view using the given name. The lifetime of this
+   * temporary view is tied to this Spark application.
+   *
+   * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
+   * i.e. it will be automatically dropped when the application terminates. It's tied to a system
+   * preserved database `_global_temp`, and we must use the qualified name to refer a global temp
+   * view, e.g. `SELECT * FROM _global_temp.view1`.
+   *
+   * @group basic
+   */
+  def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan {
+    createTempViewCommand(viewName, replace = true, global = true)
+  }
+
+  private[spark] def createTempViewCommand(
       viewName: String,
       replace: Boolean,
       global: Boolean): CreateViewCommand = {
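
A usage sketch for the new API, assuming an active SparkSession named `spark` and a made-up view name; in released Spark the global temp database is `global_temp` by default:

val df = spark.range(5).toDF("id")

// Register (or overwrite) a global temp view that every session of this
// application can see.
df.createOrReplaceGlobalTempView("ids")

// Query it from another session using the qualified name.
spark.newSession().sql("SELECT count(*) FROM global_temp.ids").show()

// Unlike createGlobalTempView, calling this again with the same name replaces
// the view instead of failing because the view already exists.
df.filter("id > 2").createOrReplaceGlobalTempView("ids")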

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 3 additions & 3 deletions

@@ -140,20 +140,20 @@ private[sql] class HiveSessionCatalog(
     // Hive is case insensitive.
     val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT)
     if (!hiveFunctions.contains(functionName)) {
-      failFunctionLookup(funcName.unquotedString)
+      failFunctionLookup(funcName)
     }
 
     // TODO: Remove this fallback path once we implement the list of fallback functions
     // defined below in hiveFunctions.
     val functionInfo = {
       try {
         Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse(
-          failFunctionLookup(funcName.unquotedString))
+          failFunctionLookup(funcName))
       } catch {
         // If HiveFunctionRegistry.getFunctionInfo throws an exception,
         // we are failing to load a Hive builtin function, which means that
         // the given function is not a Hive builtin function.
-        case NonFatal(e) => failFunctionLookup(funcName.unquotedString)
+        case NonFatal(e) => failFunctionLookup(funcName)
       }
     }
     val className = functionInfo.getFunctionClass.getName

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 20 additions & 0 deletions

@@ -34,6 +34,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.functions.max
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
 

@@ -590,6 +591,25 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("Call the function registered in the not-current database") {
+    Seq("true", "false").foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        withDatabase("dAtABaSe1") {
+          sql("CREATE DATABASE dAtABaSe1")
+          withUserDefinedFunction("dAtABaSe1.test_avg" -> false) {
+            sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'")
+            checkAnswer(sql("SELECT dAtABaSe1.test_avg(1)"), Row(1.0))
+          }
+          val message = intercept[AnalysisException] {
+            sql("SELECT dAtABaSe1.unknownFunc(1)")
+          }.getMessage
+          assert(message.contains("Undefined function: 'unknownFunc'") &&
+            message.contains("nor a permanent function registered in the database 'dAtABaSe1'"))
+        }
+      }
+    }
+  }
 }
 
class TestPair(x: Int, y: Int) extends Writable with Serializable {
