diff --git a/LICENSE b/LICENSE index 820f14dbdeed..cc1f580207a7 100644 --- a/LICENSE +++ b/LICENSE @@ -237,6 +237,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaa..9a767dd739b9 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 066275e8f842..56f3f59504a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4d..e87b2240564b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4..a6d13d12fc28 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384b..ea67b7434a76 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b8..ba9dae4ad48e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc422..df1a4bef616b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -291,7 +291,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1feb..d315ef66e0dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8b75f5d8fe1a..2e43f17e6a8e 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -60,23 +60,25 @@ private[spark] abstract class WebUI( def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager - /** Attach a tab to this UI, along with all of its attached pages. */ - def attachTab(tab: WebUITab) { + /** Attaches a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab): Unit = { tab.pages.foreach(attachPage) tabs += tab } - def detachTab(tab: WebUITab) { + /** Detaches a tab from this UI, along with all of its attached pages. */ + def detachTab(tab: WebUITab): Unit = { tab.pages.foreach(detachPage) tabs -= tab } - def detachPage(page: WebUIPage) { + /** Detaches a page from this UI, along with all of its attached handlers. */ + def detachPage(page: WebUIPage): Unit = { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } - /** Attach a page to this UI. */ - def attachPage(page: WebUIPage) { + /** Attaches a page to this UI. */ + def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) @@ -88,41 +90,41 @@ private[spark] abstract class WebUI( handlers += renderHandler } - /** Attach a handler to this UI. */ - def attachHandler(handler: ServletContextHandler) { + /** Attaches a handler to this UI. */ + def attachHandler(handler: ServletContextHandler): Unit = { handlers += handler serverInfo.foreach(_.addHandler(handler)) } - /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + /** Detaches a handler from this UI. */ + def detachHandler(handler: ServletContextHandler): Unit = { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } /** - * Add a handler for static content. + * Detaches the content handler at `path` URI. * - * @param resourceBase Root of where to find resources to serve. - * @param path Path in UI where to mount the resources. + * @param path Path in UI to unmount. */ - def addStaticHandler(resourceBase: String, path: String): Unit = { - attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + def detachHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) } /** - * Remove a static content handler. + * Adds a handler for static content. * - * @param path Path in UI to unmount. + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. */ - def removeStaticHandler(path: String): Unit = { - handlers.find(_.getContextPath() == path).foreach(detachHandler) + def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) } - /** Initialize all components of the server. */ + /** A hook to initialize components of the UI */ def initialize(): Unit - /** Bind to the HTTP server behind this web interface. */ + /** Binds to the HTTP server behind this web interface. */ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { @@ -136,17 +138,17 @@ private[spark] abstract class WebUI( } } - /** Return the url of web interface. Only valid after bind(). */ + /** @return The url of web interface. Only valid after [[bind]]. */ def webUrl: String = s"http://$publicHostName:$boundPort" - /** Return the actual port to which this server is bound. Only valid after bind(). */ + /** @return The actual port to which this server is bound. Only valid after [[bind]]. */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Stop the server behind this web interface. Only valid after bind(). */ + /** Stops the server behind this web interface. Only valid after [[bind]]. */ def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") - serverInfo.get.stop() + serverInfo.foreach(_.stop()) } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 165a15c73e7c..0f08a2b0ad89 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -229,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c14553232851..85ffdca436e1 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c7327..d967aa39a482 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c2..cc65f78b45c3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f841307..b343094b2e7b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-jmock.txt new file mode 100644 index 000000000000..ed7964fe3d9e --- /dev/null +++ b/licenses/LICENSE-jmock.txt @@ -0,0 +1,28 @@ +Copyright (c) 2000-2017, jMock.org +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. Redistributions +in binary form must reproduce the above copyright notice, this list of +conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of jMock nor the names of its contributors may be +used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pom.xml b/pom.xml index 23bbd3b09734..4b4e6c13ea8f 100644 --- a/pom.xml +++ b/pom.xml @@ -760,6 +760,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd90..4c16b5fc26f3 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a5e3384e802b..e6346691fb1d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2168,8 +2168,7 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] - >>> schema = MapType(StringType(), IntegerType()) - >>> df.select(from_json(df.value, schema).alias("json")).collect() + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1..f1ad6b1212ed 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -567,14 +567,7 @@ def _create_shell_session(): .getOrCreate() else: return SparkSession.builder.getOrCreate() - except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - - try: - return SparkSession.builder.getOrCreate() - except TypeError: + except (py4j.protocol.Py4JError, TypeError): if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': warnings.warn("Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive") diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fae50b3d5d53..4984593bab49 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -854,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. + + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d7a4f62d4ee..4e5fafa77e10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1869,6 +1869,263 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert(msg in str(e), "%s not in %s" % (msg, str(e))) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18b2f251dc9f..a4c5fb1db8b3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f27127346..a6dd47a6b7d9 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -77,6 +77,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +90,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 590deaa72e7e..bf33179ae3da 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -176,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = ConfigBuilder("spark.kubernetes.memoryOverheadFactor") .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + @@ -193,7 +211,6 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 000000000000..83daddf71448 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 000000000000..5a143ad3600f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 000000000000..b28d93990313 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running io absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. + |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 000000000000..e77e604d00e0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 000000000000..26be91804341 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. + */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 000000000000..dd264332cf9e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 000000000000..5583b4617eeb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
+ * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
+ * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
+ * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 000000000000..a6749a644e00 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d9..c6e931a38405 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -56,17 +58,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071..fa6dc2c479bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 000000000000..527fc6b0d8f8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index a8a8218c621e..d045d9ae89c0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 000000000000..f7721e6fd638 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 000000000000..c6b667ed85e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 000000000000..0c19f5946b75 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 000000000000..562ace9f49d4 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 000000000000..1b26d6af296a --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 000000000000..70e19c904edd --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 000000000000..cf54b3c4eb32 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 000000000000..ac1968b4ff81 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069..52e7a12dbaf0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. - throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. - val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6d..15bbe60d6c8f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91ca..4dd2b7365652 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } private static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e..c823de4810f2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9e9105a157ab..93df73ab1eaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -286,6 +286,7 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to the string type") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 04a4eb0ffc03..f6d74f5b74c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -747,8 +746,8 @@ case class StructsToJson( object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + def validateSchemaLiteral(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) case e => throw new AnalysisException(s"Expected a string literal instead of $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59..0cc2a332f2c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9..fd40741cfb5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f99af9b84d95..89452ee05cff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -139,4 +140,11 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(exception.getMessage.contains("The value (0.1) of the type " + "(java.lang.Double) cannot be converted to the string type")) } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a..86f80fe66d28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f4..6fdadde62855 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c..577eab6ed14c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b21..a79080a249ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,11 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before a projection or any filters are pushed to the + * DataSourceReader. Implementations that return more accurate statistics based on projection and + * filters will not improve query performance until the planner can push operators before getting + * stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f..b21c50af1843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
    + *
  • A single instance of this class is responsible of all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
  • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
  • The lifecycle of the methods are as follows. + * + *
    + *   For each partition with `partitionId`:
    + *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
    + *           Method `open(partitionId, epochId)` is called.
    + *           If `open` returns true:
    + *                For each row in the partition and batch/epoch, method `process(row)` is called.
    + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
    + *   
    + * + *
+ * + * Important points to note: + *
    + *
  • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold and therefore should not be used for + * deduplication. + * + *
  • The `close()` method will be called if `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
* * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978d..00ff4c8ac310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc..e08af218513f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources.{DataSourceRegister, Filter} @@ -32,69 +31,27 @@ import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } - - private lazy val v2Options: DataSourceOptions = makeV2Options(options) + override def pushedFilters: Seq[Expression] = Seq.empty - // postScanFilters: filters that need to be evaluated after the scan. - // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. - // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. - lazy val ( - reader: DataSourceReader, - postScanFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (postScanFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - - (newReader, postScanFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + override def simpleString: String = "RelationV2 " + metadataString - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + lazy val v2Options: DataSourceOptions = makeV2Options(options) - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } + def newReader: DataSourceReader = userSpecifiedSchema match { + case Some(userSchema) => + source.asReadSupportWithSchema.createReader(userSchema, v2Options) + case None => + source.asReadSupport.createReader(v2Options) } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -102,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -206,21 +161,27 @@ object DataSourceV2Relation { def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) + val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, output, options, userSpecifiedSchema) } - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + def pushRequiredColumns( + relation: DataSourceV2Relation, + reader: DataSourceReader, + struct: StructType): Seq[AttributeReference] = { reader match { case projectionSupport: SupportsPushDownRequiredColumns => projectionSupport.pruneColumns(struct) + // return the output columns from the relation that were projected + val attrMap = relation.output.map(a => a.name -> a).toMap + projectionSupport.readSchema().map(f => attrMap(f.name)) case _ => + relation.output } } - private def pushFilters( + def pushFilters( reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { @@ -248,7 +209,7 @@ object DataSourceV2Relation { // the data source cannot guarantee the rows returned can pass these filters. // As a result we must return it so Spark can plan an extra filter operator. val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) // The filters which are marked as pushed to this data source val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1b7c639f10f9..8bf858c38d76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,15 +17,56 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy +import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(relation.output) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + relation.output + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } + + val reader = relation.newReader + + val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, + projection.asInstanceOf[Seq[AttributeReference]].toStructType) + + val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) + + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filter = postScanFilters.reduceLeftOption(And) + val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + execution.ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index e894f8afd676..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - assert(relation.filters.isEmpty, "data source v2 should do push down only once.") - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that need to be evaluated after scan. - val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) - val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - - case other => other.mapChildren(apply) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c547..ad95879d86f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { + plan match { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = @@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 000000000000..a58773122922 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5..7fa13c4aa2c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -221,19 +222,60 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) + } else { + rows + } + } +} + +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "maxRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + + /** + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. + * @param options Options for writing from which we get the max rows option + * @return The maximum number of rows a memorySink should store. + */ + def getMemorySinkCapacity(options: DataSourceOptions): Int = { + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + if (maxRows >= 0) maxRows else Int.MaxValue - 10 + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -241,6 +283,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + var rowsToAdd = data.collect() + synchronized { + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + var rowsToAdd = data.collect() synchronized { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36..f677f25f116a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, attemptNumber: Int, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db98711..b393c48baee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3..47b482007822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +96,21 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +124,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -117,16 +132,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -134,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 87bd7b3b0f9c..8551058ec58c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3369,7 +3369,7 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e1..43e80e4e5423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +307,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d4..dc15d13cd1dd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,7 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb4..2b3288dc5a13 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 28 -- !query 0 @@ -258,3 +258,19 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e..562a756b50ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8fa747465cb1..44767dfc9249 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 055e1fc5640f..7bf17cbcd9c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -354,8 +354,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24027: from_json - map>") { val in = Seq("""{"a": {"b": 1}}""").toDS() - val schema = MapType(StringType, MapType(StringType, IntegerType)) - val out = in.select(from_json($"value", schema)) + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ed0ff1be476c..37d468739c61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -680,6 +680,23 @@ class PlannerSuite extends SharedSQLContext { assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + test("SPARK-24500: create union with stream of children") { val df = Union(Stream( Range(1, 1, 1, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 000000000000..07e603477012 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d90..b2fd6ba27ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b565..e539510e1575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -80,4 +85,73 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + appendWriter.commit(19, Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))))) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 505a3f3465c0..e96cd4500458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e4..e41b4534ed51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca..8620f3f6d99f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6f904c937348..514921875f1f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b..25e71258b936 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c..2e8599026ea1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString)