@@ -26,7 +26,7 @@ import scala.concurrent.{Awaitable, Await, Future}
2626import scala .language .postfixOps
2727
2828import org .apache .spark .{SecurityManager , SparkConf }
29- import org .apache .spark .util .{ThreadUtils , RpcUtils , Utils }
29+ import org .apache .spark .util .{RpcUtils , Utils }
3030
3131
3232/**
@@ -190,8 +190,8 @@ private[spark] object RpcAddress {
190190/**
191191 * An exception thrown if RpcTimeout modifies a [[TimeoutException ]].
192192 */
193- private [rpc] class RpcTimeoutException (message : String )
194- extends TimeoutException (message)
193+ private [rpc] class RpcTimeoutException (message : String , cause : TimeoutException )
194+ extends TimeoutException (message) { initCause(cause) }
195195
196196
197197/**
@@ -209,27 +209,23 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
209209 def message : String = description
210210
211211 /** Amends the standard message of TimeoutException to include the description */
212- def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213- new RpcTimeoutException (te.getMessage() + " " + description)
212+ private def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213+ new RpcTimeoutException (te.getMessage() + " " + description, te )
214214 }
215215
216216 /**
217- * Add a callback to the given Future so that if it completes as failed with a TimeoutException
218- * then the timeout description is added to the message
217+ * PartialFunction to match a TimeoutException and add the timeout description to the message
218+ *
219+ * @note This can be used in the recover callback of a Future to add to a TimeoutException
220+ * Example:
221+ * val timeout = new RpcTimeout(5 millis, "short timeout")
222+ * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
219223 */
220- def addMessageIfTimeout [T ](future : Future [T ]): Future [T ] = {
221- future.recover {
222- // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223- case rte : RpcTimeoutException => throw new RpcTimeoutException (rte.getMessage() +
224- " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)" )
225- // Any other TimeoutException get converted to a RpcTimeoutException with modified message
226- case te : TimeoutException => throw createRpcTimeoutException(te)
227- }(ThreadUtils .sameThread)
228- }
229-
230- /** Applies the duration to create future before calling addMessageIfTimeout*/
231- def addMessageIfTimeout [T ](f : FiniteDuration => Future [T ]): Future [T ] = {
232- addMessageIfTimeout(f(duration))
224+ def addMessageIfTimeout [T ]: PartialFunction [Throwable , T ] = {
225+ // The exception has already been converted to a RpcTimeoutException so just raise it
226+ case rte : RpcTimeoutException => throw rte
227+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
228+ case te : TimeoutException => throw createRpcTimeoutException(te)
233229 }
234230
235231 /**
@@ -241,13 +237,7 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
241237 def awaitResult [T ](awaitable : Awaitable [T ]): T = {
242238 try {
243239 Await .result(awaitable, duration)
244- }
245- catch {
246- // The exception has already been converted to a RpcTimeoutException so just raise it
247- case rte : RpcTimeoutException => throw rte
248- // Any other TimeoutException get converted to a RpcTimeoutException with modified message
249- case te : TimeoutException => throw createRpcTimeoutException(te)
250- }
240+ } catch addMessageIfTimeout
251241 }
252242}
253243
@@ -299,13 +289,10 @@ object RpcTimeout {
299289
300290 // Find the first set property or use the default value with the first property
301291 val itr = timeoutPropList.iterator
302- var foundProp = None : Option [(String , String )]
292+ var foundProp : Option [(String , String )] = None
303293 while (itr.hasNext && foundProp.isEmpty){
304294 val propKey = itr.next()
305- conf.getOption(propKey) match {
306- case Some (prop) => foundProp = Some (propKey,prop)
307- case None =>
308- }
295+ conf.getOption(propKey).foreach { prop => foundProp = Some (propKey, prop) }
309296 }
310297 val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
311298 val timeout = { Utils .timeStringAsSeconds(finalProp._2) seconds }
0 commit comments