@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.concurrent.RejectedExecutionException

import scala.language.existentials
import scala.util.control.NonFatal
@@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason : TaskEndReason = UnknownReason
Contributor Author (inline comment on the added try/catch): This is not a directly related change; it mutes the exception when you cancel a job.

-    getTaskResultExecutor.execute(new Runnable {
-      override def run(): Unit = Utils.logUncaughtExceptions {
-        try {
-          if (serializedData != null && serializedData.limit() > 0) {
-            reason = serializer.get().deserialize[TaskEndReason](
-              serializedData, Utils.getSparkClassLoader)
-          }
-        } catch {
-          case cnd: ClassNotFoundException =>
-            // Log an error but keep going here -- the task failed, so not catastrophic if we can't
-            // deserialize the reason.
-            val loader = Utils.getContextOrSparkClassLoader
-            logError(
-              "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
-          case ex: Exception => {}
-        }
-        scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
-      }
-    })
+    try {
+      getTaskResultExecutor.execute(new Runnable {
+        override def run(): Unit = Utils.logUncaughtExceptions {
+          try {
+            if (serializedData != null && serializedData.limit() > 0) {
+              reason = serializer.get().deserialize[TaskEndReason](
+                serializedData, Utils.getSparkClassLoader)
+            }
+          } catch {
+            case cnd: ClassNotFoundException =>
+              // Log an error but keep going here -- the task failed, so not catastrophic
+              // if we can't deserialize the reason.
+              val loader = Utils.getContextOrSparkClassLoader
+              logError(
+                "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+            case ex: Exception => {}
+          }
+          scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
+        }
+      })
+    } catch {
+      case e: RejectedExecutionException if sparkEnv.isStopped =>
+        // ignore it
+    }
  }

def stop() {
67 changes: 67 additions & 0 deletions examples/src/main/python/status_api_demo.py
@@ -0,0 +1,67 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import threading
import Queue

from pyspark import SparkConf, SparkContext


def delayed(seconds):
def f(x):
time.sleep(seconds)
return x
return f


def call_in_background(f, *args):
result = Queue.Queue(1)
t = threading.Thread(target=lambda: result.put(f(*args)))
t.daemon = True
t.start()
return result


def main():
conf = SparkConf().set("spark.ui.showConsoleProgress", "false")
sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf)

def run():
rdd = sc.parallelize(range(10), 10).map(delayed(2))
reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
return reduced.map(delayed(2)).collect()

result = call_in_background(run)
status = sc.statusTracker()
while result.empty():
ids = status.getJobIdsForGroup()
for id in ids:
job = status.getJobInfo(id)
print "Job", id, "status: ", job.status
for sid in job.stageIds:
info = status.getStageInfo(sid)
if info:
print "Stage %d: %d tasks total (%d active, %d complete)" % \
(sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
time.sleep(1)

print "Job results are:", result.get()
sc.stop()

if __name__ == "__main__":
main()
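The demo above polls the default job group. As a rough sketch of a variation that is not part of this PR, the same tracker calls can be scoped to a named job group via sc.setJobGroup, so only that group's jobs are reported; the group id "demo_group", the app name, and the toy workload below are illustrative assumptions, and setJobGroup must be called from the thread that submits the job.

import threading
import time

from pyspark import SparkContext

sc = SparkContext(appName="GroupedStatusDemo")


def run_grouped_job():
    # setJobGroup tags jobs submitted from *this* thread with the group id.
    sc.setJobGroup("demo_group", "slow reduce for the status demo")
    rdd = sc.parallelize(range(100), 4).map(lambda x: (x % 10, 1))
    return rdd.reduceByKey(lambda a, b: a + b).collect()


t = threading.Thread(target=run_grouped_job)
t.daemon = True
t.start()

tracker = sc.statusTracker()
while t.is_alive():
    # Only jobs tagged with "demo_group" are returned here.
    for job_id in tracker.getJobIdsForGroup("demo_group"):
        info = tracker.getJobInfo(job_id)
        if info:
            print "Job %d status: %s" % (job_id, info.status)
    time.sleep(1)

sc.stop()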
15 changes: 8 additions & 7 deletions python/pyspark/__init__.py
@@ -22,17 +22,17 @@

- :class:`SparkContext`:
Main entry point for Spark functionality.
- - L{RDD}
+ - :class:`RDD`:
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- - L{Broadcast}
+ - :class:`Broadcast`:
A broadcast variable that gets reused across tasks.
- - L{Accumulator}
+ - :class:`Accumulator`:
An "add-only" shared variable that tasks can only add values to.
- - L{SparkConf}
+ - :class:`SparkConf`:
For configuring Spark.
- - L{SparkFiles}
+ - :class:`SparkFiles`:
Access files shipped with jobs.
- - L{StorageLevel}
+ - :class:`StorageLevel`:
Finer-grained cache persistence levels.

"""
@@ -45,6 +45,7 @@
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
from pyspark.status import *
from pyspark.profiler import Profiler, BasicProfiler

# for back compatibility
@@ -53,5 +54,5 @@
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
"Profiler", "BasicProfiler",
"StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
]
7 changes: 7 additions & 0 deletions python/pyspark/context.py
@@ -32,6 +32,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler

from py4j.java_collections import ListConverter
@@ -808,6 +809,12 @@ def cancelAllJobs(self):
"""
self._jsc.sc().cancelAllJobs()

def statusTracker(self):
"""
Return a :class:`StatusTracker` object.
"""
return StatusTracker(self._jsc.statusTracker())
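A quick usage sketch, not taken from the PR: the tracker returned here can be polled until the listener reports no active work before shutting down. The app name and the bare polling loop are assumptions; the tracker reflects the event listener's view, which may lag slightly behind the scheduler.

import time

from pyspark import SparkContext

sc = SparkContext(appName="WaitForJobs")
tracker = sc.statusTracker()

# ... submit jobs from background threads here ...

# Best-effort wait: loop until no active jobs or stages are reported.
while tracker.getActiveJobsIds() or tracker.getActiveStageIds():
    time.sleep(0.5)

sc.stop()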

def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
Executes the given partitionFunc on the specified set of partitions,
96 changes: 96 additions & 0 deletions python/pyspark/status.py
@@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple

__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"]


class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")):
"""
Exposes information about Spark Jobs.
"""


class SparkStageInfo(namedtuple("SparkStageInfo",
"stageId currentAttemptId name numTasks numActiveTasks "
"numCompletedTasks numFailedTasks")):
"""
Exposes information about Spark Stages.
"""


class StatusTracker(object):
"""
Low-level status reporting APIs for monitoring job and stage progress.

These APIs intentionally provide very weak consistency semantics;
consumers of these APIs should be prepared to handle empty / missing
information. For example, a job's stage ids may be known but the status
API may not have any information about the details of those stages, so
`getStageInfo` could potentially return `None` for a valid stage id.

To limit memory usage, these APIs only provide information on recent
jobs / stages. These APIs will provide information for the last
`spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs.
"""
def __init__(self, jtracker):
self._jtracker = jtracker

def getJobIdsForGroup(self, jobGroup=None):
"""
Return a list of all known jobs in a particular job group. If
`jobGroup` is None, then returns all known jobs that are not
associated with a job group.

The returned list may contain running, failed, and completed jobs,
and may vary across invocations of this method. This method does
not guarantee the order of the elements in its result.
"""
return list(self._jtracker.getJobIdsForGroup(jobGroup))

def getActiveStageIds(self):
"""
Returns an array containing the ids of all active stages.
"""
return sorted(list(self._jtracker.getActiveStageIds()))

def getActiveJobsIds(self):
"""
Returns an array containing the ids of all active jobs.
"""
return sorted((list(self._jtracker.getActiveJobIds())))

def getJobInfo(self, jobId):
"""
Returns a :class:`SparkJobInfo` object, or None if the job info
could not be found or was garbage collected.
"""
job = self._jtracker.getJobInfo(jobId)
if job is not None:
return SparkJobInfo(jobId, job.stageIds(), str(job.status()))

def getStageInfo(self, stageId):
"""
Returns a :class:`SparkStageInfo` object, or None if the stage
info could not be found or was garbage collected.
"""
stage = self._jtracker.getStageInfo(stageId)
if stage is not None:
# TODO: fetch them in batch for better performance
attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]]
return SparkStageInfo(stageId, *attrs)
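Given the weak-consistency caveats in the StatusTracker docstring above, a defensive polling helper might look like the sketch below. It is illustrative only and not part of the PR; summarize_group and its output format are assumptions, while every tracker call and field comes from the class as defined here.

def summarize_group(sc, group_id):
    """Build human-readable progress lines for one job group."""
    tracker = sc.statusTracker()
    lines = []
    for job_id in tracker.getJobIdsForGroup(group_id):
        job = tracker.getJobInfo(job_id)
        if job is None:
            # The job id was listed, but its info has already been dropped.
            continue
        lines.append("job %d: %s" % (job_id, job.status))
        for stage_id in job.stageIds:
            stage = tracker.getStageInfo(stage_id)
            if stage is None:
                # getStageInfo may return None even for a valid stage id.
                continue
            lines.append("  stage %d (%s): %d/%d tasks done, %d failed" % (
                stage_id, stage.name, stage.numCompletedTasks,
                stage.numTasks, stage.numFailedTasks))
    return lines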
31 changes: 31 additions & 0 deletions python/pyspark/tests.py
@@ -1550,6 +1550,37 @@ def test_with_stop(self):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)

def test_progress_api(self):
with SparkContext() as sc:
sc.setJobGroup('test_progress_api', '', True)

rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
t = threading.Thread(target=rdd.collect)
t.daemon = True
t.start()
# wait for scheduler to start
time.sleep(1)

tracker = sc.statusTracker()
jobIds = tracker.getJobIdsForGroup('test_progress_api')
self.assertEqual(1, len(jobIds))
job = tracker.getJobInfo(jobIds[0])
self.assertEqual(1, len(job.stageIds))
stage = tracker.getStageInfo(job.stageIds[0])
self.assertEqual(rdd.getNumPartitions(), stage.numTasks)

sc.cancelAllJobs()
t.join()
# wait for event listener to update the status
time.sleep(1)

job = tracker.getJobInfo(jobIds[0])
self.assertEqual('FAILED', job.status)
self.assertEqual([], tracker.getActiveJobsIds())
self.assertEqual([], tracker.getActiveStageIds())

sc.stop()


@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):