
Commit 508fee0

mn-mikke authored and committed
[SPARK-23821][SQL] Merging current master to the feature branch.
2 parents 37b68cd + e6b4660

29 files changed: +1689 −287 lines

core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
Lines changed: 1 addition & 1 deletion

@@ -210,7 +210,7 @@ private[scheduler] class BlacklistTracker (
         updateNextExpiryTime()
         killBlacklistedExecutor(exec)

-        val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]())
+        val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]())
         blacklistedExecsOnNode += exec
       }
     }
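
The one-character change above keys the per-node set by `host` rather than by the executor ID, so all blacklisted executors on the same node accumulate in one shared set. A minimal, self-contained sketch of that bookkeeping (simplified stand-ins, not the real BlacklistTracker class):

import scala.collection.mutable.{HashMap, HashSet}

// Simplified stand-in for the tracker's bookkeeping: one entry per host,
// holding the set of blacklisted executors that live on that host.
val nodeToBlacklistedExecs = HashMap.empty[String, HashSet[String]]

def recordBlacklistedExecutor(host: String, exec: String): Unit = {
  // Keying by `host` (the fix) groups executors per node; keying by `exec`
  // (the bug) would create a singleton set per executor instead.
  val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]())
  blacklistedExecsOnNode += exec
}

recordBlacklistedExecutor("hostA", "1")
recordBlacklistedExecutor("hostA", "2")
assert(nodeToBlacklistedExecs("hostA") == HashSet("1", "2"))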

core/src/test/scala/org/apache/spark/SparkContextSuite.scala
Lines changed: 26 additions & 35 deletions

@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.File
 import java.net.{MalformedURLException, URI}
 import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}

 import scala.concurrent.duration._

@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu

   test("Cancelling stages/jobs with custom reasons.") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
     val REASON = "You shall not pass"
-    val slices = 10

-    val listener = new SparkListener {
-      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        if (SparkContextSuite.cancelStage) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+    for (cancelWhat <- Seq("stage", "job")) {
+      // This countdown latch used to make sure stage or job canceled in listener
+      val latch = new CountDownLatch(1)
+
+      val listener = cancelWhat match {
+        case "stage" =>
+          new SparkListener {
+            override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+              sc.cancelStage(taskStart.stageId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelStage(taskStart.stageId, REASON)
-          SparkContextSuite.cancelStage = false
-          SparkContextSuite.semaphore.release(slices)
-        }
-      }
-
-      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-        if (SparkContextSuite.cancelJob) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
+        case "job" =>
+          new SparkListener {
+            override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+              sc.cancelJob(jobStart.jobId, REASON)
+              latch.countDown()
+            }
           }
-          sc.cancelJob(jobStart.jobId, REASON)
-          SparkContextSuite.cancelJob = false
-          SparkContextSuite.semaphore.release(slices)
-        }
       }
-    }
-    sc.addSparkListener(listener)
-
-    for (cancelWhat <- Seq("stage", "job")) {
-      SparkContextSuite.semaphore.drainPermits()
-      SparkContextSuite.isTaskStarted = false
-      SparkContextSuite.cancelStage = (cancelWhat == "stage")
-      SparkContextSuite.cancelJob = (cancelWhat == "job")
+      sc.addSparkListener(listener)

       val ex = intercept[SparkException] {
-        sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
-          SparkContextSuite.isTaskStarted = true
-          // Block waiting for the listener to cancel the stage or job.
-          SparkContextSuite.semaphore.acquire()
+        sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+          x.synchronized {
+            x.wait()
+          }
           x
         }.count()
       }

@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
         fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
       }

+      latch.await(20, TimeUnit.SECONDS)
       eventually(timeout(20.seconds)) {
         assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
       }
+      sc.removeSparkListener(listener)
     }
   }

@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
 }

 object SparkContextSuite {
-  @volatile var cancelJob = false
-  @volatile var cancelStage = false
   @volatile var isTaskStarted = false
   @volatile var taskKilled = false
   @volatile var taskSucceeded = false
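
The rewritten test drops the shared @volatile flags and semaphore in favour of a per-iteration CountDownLatch that the listener counts down once it has issued the cancellation. A stripped-down sketch of that pattern for the job case, assuming an already-running SparkContext `sc` (the range/count workload and timeouts are illustrative, not the test itself):

import java.util.concurrent.{CountDownLatch, TimeUnit}
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

val REASON = "You shall not pass"
val latch = new CountDownLatch(1)
val listener = new SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    sc.cancelJob(jobStart.jobId, REASON)  // cancel as soon as the job is submitted
    latch.countDown()                     // signal that cancellation was requested
  }
}
sc.addSparkListener(listener)
try {
  sc.range(0, 10000L, numSlices = 10).count()
} catch {
  case _: SparkException => ()            // expected: the job was cancelled with REASON
} finally {
  latch.await(20, TimeUnit.SECONDS)       // wait for the listener to have fired
  sc.removeSparkListener(listener)
}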

core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
Lines changed: 5 additions & 0 deletions

@@ -574,6 +574,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     verify(allocationClientMock, never).killExecutors(any(), any(), any(), any())
     verify(allocationClientMock, never).killExecutorsOnHost(any())

+    assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+    assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))
+
     // Enable auto-kill. Blacklist an executor and make sure killExecutors is called.
     conf.set(config.BLACKLIST_KILL_ENABLED, true)
     blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)

@@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
       1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
     assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
     assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty)
+    assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+    assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))

     // Enable external shuffle service to see if all the executors on this node will be killed.
     conf.set(config.SHUFFLE_SERVICE_ENABLED, true)

docs/sql-programming-guide.md
Lines changed: 1 addition & 1 deletion

@@ -1810,7 +1810,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
   - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema.
   - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0.
   - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0.
-
+  - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.

 ## Upgrading From Spark SQL 2.2 to 2.3

 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
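
For the new migration note on variadic type coercion, a hypothetical illustration (assuming an existing SparkSession `spark`; the literal values are made up):

// Illustration only: `spark` is assumed to be an existing SparkSession.
// Under the Spark 2.4 rules described above, the arguments of COALESCE are
// promoted to a common type regardless of the order in which they appear.
spark.sql("SELECT coalesce(CAST(NULL AS TIMESTAMP), 1, '2017-12-12 09:30:00')").show()
// In earlier versions the same query could fail with a type-coercion error
// for certain argument orders (e.g. timestamp, int, string).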
org/apache/spark/ml/clustering/PowerIterationClustering.scala (new file)
Lines changed: 256 additions & 0 deletions

@@ -0,0 +1,256 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.clustering

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

/**
 * Common params for PowerIterationClustering
 */
private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
  with HasPredictionCol {

  /**
   * The number of clusters to create (k). Must be &gt; 1. Default: 2.
   * @group param
   */
  @Since("2.4.0")
  final val k = new IntParam(this, "k", "The number of clusters to create. " +
    "Must be > 1.", ParamValidators.gt(1))

  /** @group getParam */
  @Since("2.4.0")
  def getK: Int = $(k)

  /**
   * Param for the initialization algorithm. This can be either "random" to use a random vector
   * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices.
   * Default: random.
   * @group expertParam
   */
  @Since("2.4.0")
  final val initMode = {
    val allowedParams = ParamValidators.inArray(Array("random", "degree"))
    new Param[String](this, "initMode", "The initialization algorithm. This can be either " +
      "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " +
      "of similarities with other vertices. Supported options: 'random' and 'degree'.",
      allowedParams)
  }

  /** @group expertGetParam */
  @Since("2.4.0")
  def getInitMode: String = $(initMode)

  /**
   * Param for the name of the input column for vertex IDs.
   * Default: "id"
   * @group param
   */
  @Since("2.4.0")
  val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.",
    (value: String) => value.nonEmpty)

  setDefault(idCol, "id")

  /** @group getParam */
  @Since("2.4.0")
  def getIdCol: String = getOrDefault(idCol)

  /**
   * Param for the name of the input column for neighbors in the adjacency list representation.
   * Default: "neighbors"
   * @group param
   */
  @Since("2.4.0")
  val neighborsCol = new Param[String](this, "neighborsCol",
    "Name of the input column for neighbors in the adjacency list representation.",
    (value: String) => value.nonEmpty)

  setDefault(neighborsCol, "neighbors")

  /** @group getParam */
  @Since("2.4.0")
  def getNeighborsCol: String = $(neighborsCol)

  /**
   * Param for the name of the input column for non-negative edge weights (similarities) in the
   * adjacency list representation.
   * Default: "similarities"
   * @group param
   */
  @Since("2.4.0")
  val similaritiesCol = new Param[String](this, "similaritiesCol",
    "Name of the input column for non-negative edge weights (similarities) in the adjacency " +
    "list representation.",
    (value: String) => value.nonEmpty)

  setDefault(similaritiesCol, "similarities")

  /** @group getParam */
  @Since("2.4.0")
  def getSimilaritiesCol: String = $(similaritiesCol)

  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType))
    SchemaUtils.checkColumnTypes(schema, $(neighborsCol),
      Seq(ArrayType(IntegerType, containsNull = false),
        ArrayType(LongType, containsNull = false)))
    SchemaUtils.checkColumnTypes(schema, $(similaritiesCol),
      Seq(ArrayType(FloatType, containsNull = false),
        ArrayType(DoubleType, containsNull = false)))
    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
  }
}

/**
 * :: Experimental ::
 * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
 * <a href=http://www.icml2010.org/papers/387.pdf>Lin and Cohen</a>. From the abstract:
 * PIC finds a very low-dimensional embedding of a dataset using truncated power
 * iteration on a normalized pair-wise similarity matrix of the data.
 *
 * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix
 * is a symmetric matrix whose entries are non-negative similarities between items.
 * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes:
 *  - `idCol`: vertex ID
 *  - `neighborsCol`: neighbors of vertex in `idCol`
 *  - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex
 *    in `idCol` and each neighbor in `neighborsCol`
 * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol`
 * containing the cluster assignment in `[0,k)` for each row (vertex).
 *
 * Notes:
 *  - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation.
 *    Transform runs the iterative PIC algorithm to cluster the whole input dataset.
 *  - Input validation: This validates that similarities are non-negative but does NOT validate
 *    that the input matrix is symmetric.
 *
 * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering>
 *      Spectral clustering (Wikipedia)</a>
 */
@Since("2.4.0")
@Experimental
class PowerIterationClustering private[clustering] (
    @Since("2.4.0") override val uid: String)
  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {

  setDefault(
    k -> 2,
    maxIter -> 20,
    initMode -> "random")

  @Since("2.4.0")
  def this() = this(Identifiable.randomUID("PowerIterationClustering"))

  /** @group setParam */
  @Since("2.4.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setK(value: Int): this.type = set(k, value)

  /** @group expertSetParam */
  @Since("2.4.0")
  def setInitMode(value: String): this.type = set(initMode, value)

  /** @group setParam */
  @Since("2.4.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  @Since("2.4.0")
  def setIdCol(value: String): this.type = set(idCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setNeighborsCol(value: String): this.type = set(neighborsCol, value)

  /** @group setParam */
  @Since("2.4.0")
  def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value)

  @Since("2.4.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)

    val sparkSession = dataset.sparkSession
    val idColValue = $(idCol)
    val rdd: RDD[(Long, Long, Double)] =
      dataset.select(
        col($(idCol)).cast(LongType),
        col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)),
        col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false))
      ).rdd.flatMap {
        case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
          require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " +
            s"equal to the length of the neighbor similarity list. Row for ID " +
            s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " +
            s"of length ${sims.length}.")
          nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
            case (nbr, similarity) => (id, nbr, similarity)
          }
      }
    val algorithm = new MLlibPowerIterationClustering()
      .setK($(k))
      .setInitializationMode($(initMode))
      .setMaxIterations($(maxIter))
    val model = algorithm.run(rdd)

    val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
      Row(assignment.id, assignment.cluster)
    }

    val predictionsSchema = StructType(Seq(
      StructField($(idCol), LongType, nullable = false),
      StructField($(predictionCol), IntegerType, nullable = false)))
    val predictions = {
      val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
      dataset.schema($(idCol)).dataType match {
        case _: LongType =>
          uncastPredictions
        case otherType =>
          // Cast the IDs back to the caller's ID type, keeping the prediction column.
          uncastPredictions.select(
            col($(idCol)).cast(otherType).alias($(idCol)),
            col($(predictionCol)))
      }
    }

    dataset.join(predictions, $(idCol))
  }

  @Since("2.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("2.4.0")
  override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra)
}

@Since("2.4.0")
object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] {

  @Since("2.4.0")
  override def load(path: String): PowerIterationClustering = super.load(path)
}
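
A short, hypothetical usage sketch for the new transformer, assuming an existing SparkSession `spark`; the three-vertex affinity graph is made up and the column names match the defaults declared above:

// Hypothetical usage sketch: `spark` is an existing SparkSession and the
// tiny undirected graph below is made up for illustration.
import org.apache.spark.ml.clustering.PowerIterationClustering
import spark.implicits._

val affinities = Seq(
  (0L, Array(1L),     Array(1.0)),
  (1L, Array(0L, 2L), Array(1.0, 1.0)),
  (2L, Array(1L),     Array(1.0))
).toDF("id", "neighbors", "similarities")   // default idCol/neighborsCol/similaritiesCol names

val pic = new PowerIterationClustering()
  .setK(2)
  .setMaxIter(10)
  .setInitMode("degree")

// transform() runs the iterative PIC algorithm and joins back a "prediction"
// column (the default predictionCol) holding the cluster assignment per vertex.
val assignments = pic.transform(affinities)
assignments.select("id", "prediction").show()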
