diff --git a/assembly/pom.xml b/assembly/pom.xml index 7ce30179e9ca2..82a5985504b4e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -29,9 +29,12 @@ spark-assembly_2.10 Spark Project Assembly http://spark.apache.org/ + pom - ${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar + scala-${scala.binary.version} + ${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar + ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} spark /usr/share/spark root diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 6af383db65d47..711156337b7c3 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -55,6 +55,15 @@ **/* + + + ${project.parent.basedir}/assembly/target/${spark.jar.dir} + + / + + ${spark.jar.basename} + + @@ -75,6 +84,8 @@ org.apache.hadoop:*:jar org.apache.spark:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar diff --git a/bagel/pom.xml b/bagel/pom.xml index 41aacbd88a7d7..142f75c5d2c64 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/bin/spark-class b/bin/spark-class index c4225a392d6da..229ae2cebbab3 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -40,34 +40,46 @@ if [ -z "$1" ]; then exit 1 fi -# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable -# values for that; it doesn't need a lot -if [ "$1" = "org.apache.spark.deploy.master.Master" -o "$1" = "org.apache.spark.deploy.worker.Worker" ]; then - SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} - SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" - # Do not overwrite SPARK_JAVA_OPTS environment variable in this script - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default -else - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" +if [ -n "$SPARK_MEM" ]; then + echo "Warning: SPARK_MEM is deprecated, please use a more specific config option" + echo "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)." fi +# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options +DEFAULT_MEM=${SPARK_MEM:-512m} + +SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" -# Add java opts for master, worker, executor. The opts maybe null +# Add java opts and memory settings for master, worker, executors, and repl. case "$1" in + # Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. 'org.apache.spark.deploy.master.Master') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS" + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS" + OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} ;; 'org.apache.spark.deploy.worker.Worker') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS" + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS" + OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} ;; + + # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. 
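+ # For example (illustrative values): with SPARK_EXECUTOR_MEMORY=2g set, the
+ # executor branch below launches with -Xms2g -Xmx2g, while a Master with
+ # SPARK_DAEMON_MEMORY unset falls back to DEFAULT_MEM (SPARK_MEM if present, else 512m).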
'org.apache.spark.executor.CoarseGrainedExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; 'org.apache.spark.executor.MesosExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; + + # All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS. 'org.apache.spark.repl.Main') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS" + OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} + ;; + *) + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" + OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} ;; esac @@ -83,14 +95,10 @@ else fi fi -# Set SPARK_MEM if it isn't already set since we also use it for this process -SPARK_MEM=${SPARK_MEM:-512m} -export SPARK_MEM - # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="$OUR_JAVA_OPTS" JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH" -JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM" +JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e "$FWDIR/conf/java-opts" ] ; then JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 80818c78ec24b..f488cfdbeceb6 100755 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -34,22 +34,45 @@ if not "x%1"=="x" goto arg_given goto exit :arg_given -set RUNNING_DAEMON=0 -if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1 -if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1 -if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m +if not "x%SPARK_MEM%"=="x" ( + echo Warning: SPARK_MEM is deprecated, please use a more specific config option + echo e.g., spark.executor.memory or SPARK_DRIVER_MEMORY. +) + +rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options +set OUR_JAVA_MEM=%SPARK_MEM% +if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m + set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true -if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY% -rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script -if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% -if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -rem Figure out how much memory to use per executor and set it as an environment -rem variable so that our process sees it and can report it to Mesos -if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m +rem Add java opts and memory settings for master, worker, executors, and repl. +rem Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. +if "%1"=="org.apache.spark.deploy.master.Master" ( + set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS% + if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% +) else if "%1"=="org.apache.spark.deploy.worker.Worker" ( + set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS% + if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% + +rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. 
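+rem For example (illustrative): "set SPARK_EXECUTOR_MEMORY=2g" makes the executor
+rem branch below launch with -Xms2g -Xmx2g, mirroring the bash spark-class script.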
+) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% + if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% +) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% + if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% + +rem All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS. +) else if "%1"=="org.apache.spark.repl.Main" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_REPL_OPTS% + if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% +) else ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% + if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% +) rem Set JAVA_OPTS to be able to load native libraries and to set heap size -set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% +set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! rem Test whether the user has built Spark diff --git a/bin/spark-shell b/bin/spark-shell index 2bff06cf70051..a0b63f1be34e6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -30,71 +30,367 @@ esac # Enter posix mode for bash set -o posix -CORE_PATTERN="^[0-9]+$" -MEM_PATTERN="^[0-9]+[m|g|M|G]$" - +## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" -if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then - echo "Usage: spark-shell [OPTIONS]" - echo "OPTIONS:" - echo "-c --cores num, the maximum number of cores to be used by the spark shell" - echo "-em --execmem num[m|g], the memory used by each executor of spark shell" - echo "-dm --drivermem num[m|g], the memory used by the spark shell and driver" - echo "-h --help, print this help information" - exit -fi +VERBOSE=0 +DRY_RUN=0 +SPARK_REPL_OPTS="${SPARK_REPL_OPTS:-""}" +MASTER="" + +#CLI Color Templates +txtund=$(tput sgr 0 1) # Underline +txtbld=$(tput bold) # Bold +bldred=${txtbld}$(tput setaf 1) # red +bldyel=${txtbld}$(tput setaf 3) # yellow +bldblu=${txtbld}$(tput setaf 4) # blue +bldwht=${txtbld}$(tput setaf 7) # white +txtrst=$(tput sgr0) # Reset +info=${bldwht}*${txtrst} # Feedback +pass=${bldblu}*${txtrst} +warn=${bldred}*${txtrst} +ques=${bldblu}?${txtrst} + +# Helper function to describe the script usage +function usage() { + cat << EOF + +${txtbld}Usage${txtrst}: spark-shell [OPTIONS] + +${txtbld}OPTIONS${txtrst}: + +${txtund}basic${txtrst}: + + -h --help : print this help information. + -c --executor-cores : the maximum number of cores to be used by the spark shell. + -em --executor-memory : num[m|g], the memory used by each executor of spark shell. + -dm --drivermem --driver-memory : num[m|g], the memory used by the spark shell and driver. + +${txtund}soon to be deprecated${txtrst}: + + --cores : please use -c/--executor-cores + +${txtund}other options${txtrst}: -SPARK_SHELL_OPTS="" + -mip --master-ip : Spark Master IP/Host Address + -mp --master-port : num, Spark Master Port + -m --master : full string that describes the Spark Master. + -ld --local-dir : absolute path to a local directory that will be use for "scratch" space in Spark. + -dh --driver-host : hostname or IP address for the driver to listen on. + -dp --driver-port : num, port for the driver to listen on. 
+ -uip --ui-port : num, port for your application's dashboard, which shows memory and workload data. + --parallelism : num, default number of tasks to use across the cluster for distributed shuffle operations when not set by user. + --locality-wait : num, number of milliseconds to wait to launch a data-local task before giving up. + --schedule-fair : flag, enables FAIR scheduling between jobs submitted to the same SparkContext. + --max-failures : num, number of individual task failures before giving up on the job. + --log-conf : flag, log the supplied SparkConf as INFO at start of spark context. + +e.g. + spark-shell -m local -ld /tmp -dh 127.0.0.1 -dp 4001 -uip 4010 --parallelism 10 --locality-wait 500 --schedule-fair --max-failures 100 + +EOF +} + +function out_error(){ + echo -e "${txtund}${bldred}ERROR${txtrst}: $1" + usage + exit 1 +} + +function log_info(){ + [ $VERBOSE -eq 1 ] && echo -e "${bldyel}INFO${txtrst}: $1" +} + +function log_warn(){ + echo -e "${txtund}${bldyel}WARN${txtrst}: $1" +} + +# PATTERNS used to validate more than one optional arg. +ARG_FLAG_PATTERN="^-" +MEM_PATTERN="^[0-9]+[m|g|M|G]$" +NUM_PATTERN="^[0-9]+$" +PORT_PATTERN="^[0-9]+$" -for o in "$@"; do - if [ "$1" = "-c" -o "$1" = "--cores" ]; then - shift +# Setters for optional args. +function set_cores(){ + CORE_PATTERN="^[0-9]+$" if [[ "$1" =~ $CORE_PATTERN ]]; then - SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.cores.max=$1" - shift + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.cores.max=$1" else - echo "ERROR: wrong format for -c/--cores" - exit 1 + out_error "wrong format for $2" fi - fi - if [ "$1" = "-em" -o "$1" = "--execmem" ]; then - shift +} + +function set_em(){ if [[ $1 =~ $MEM_PATTERN ]]; then - SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.executor.memory=$1" - shift + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.executor.memory=$1" else - echo "ERROR: wrong format for --execmem/-em" - exit 1 + out_error "wrong format for $2" fi - fi - if [ "$1" = "-dm" -o "$1" = "--drivermem" ]; then - shift +} + +function set_dm(){ if [[ $1 =~ $MEM_PATTERN ]]; then - export SPARK_MEM=$1 - shift + export SPARK_DRIVER_MEMORY=$1 else - echo "ERROR: wrong format for --drivermem/-dm" - exit 1 + out_error "wrong format for $2" fi - fi -done +} -# Set MASTER from spark-env if possible -DEFAULT_SPARK_MASTER_PORT=7077 -if [ -z "$MASTER" ]; then - if [ -e "$FWDIR/conf/spark-env.sh" ]; then - . "$FWDIR/conf/spark-env.sh" - fi - if [ "x" != "x$SPARK_MASTER_IP" ]; then - if [ "y" != "y$SPARK_MASTER_PORT" ]; then - SPARK_MASTER_PORT="${SPARK_MASTER_PORT}" +function set_localdir(){ + LOCAL_DIR_PATTERN="\/.+" + if [[ "$1" =~ $LOCAL_DIR_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.local.dir=$1" else - SPARK_MASTER_PORT=$DEFAULT_SPARK_MASTER_PORT + out_error "wrong format for $2" fi - export MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" - fi -fi +} + +function set_driver_host(){ + if ! 
[[ "$1" =~ $ARG_FLAG_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.driver.host=$1" + else + out_error "wrong format for $2" + fi +} + +function set_driver_port(){ + if [[ $1 =~ $PORT_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.driver.port=$1" + else + out_error "wrong format for $2" + fi +} + +function set_uip(){ + if [[ $1 =~ $PORT_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.ui.port=$1" + else + out_error "wrong format for $2" + fi +} + +function set_parallelism(){ + if [[ $1 =~ $NUM_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.default.parallelism=$1" + else + out_error "wrong format for $2" + fi +} + +function set_locality_wait(){ + if [[ $1 =~ $NUM_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.locality.wait=$1" + else + out_error "wrong format for $2" + fi +} + +function set_spark_scheduler(){ + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.scheduler.mode=$1" +} + +function set_spark_max_failures(){ + if [[ $1 =~ $NUM_PATTERN ]]; then + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.task.maxFailures=$1" + else + out_error "wrong format for $2" + fi +} + +function set_spark_log_conf(){ + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.logConf=$1" +} + +function set_spark_master_ip() { + if ! [[ "$1" =~ $ARG_FLAG_PATTERN ]]; then + SPARK_MASTER_IP="$1" + else + out_error "wrong format for $2" + fi +} + +function set_spark_master_port() { + if [[ $1 =~ $PORT_PATTERN ]]; then + SPARK_MASTER_PORT="$1" + else + out_error "wrong format for $2" + fi +} + +function set_spark_master(){ + if ! [[ "$1" =~ $ARG_FLAG_PATTERN ]]; then + MASTER="$1" + else + out_error "wrong format for $2" + fi +} + + +function resolve_spark_master(){ + # Set MASTER from spark-env if possible + DEFAULT_SPARK_MASTER_PORT=7077 + if [ -z "$MASTER" ]; then + if [ -e "$FWDIR/conf/spark-env.sh" ]; then + . "$FWDIR/conf/spark-env.sh" + fi + if [ -n "$SPARK_MASTER_IP" ]; then + SPARK_MASTER_PORT="${SPARK_MASTER_PORT:-"$DEFAULT_SPARK_MASTER_PORT"}" + export MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" + fi + fi + + if [ -z "$MASTER" ]; then + out_error "Unable to define a Spark Master, please either define a $FWDIR/conf/spark-env.sh or see usage with -h" + fi + +} + +function main(){ + log_info "Base Directory set to $FWDIR" + + resolve_spark_master + log_info "Spark Master is $MASTER" + + log_info "Spark REPL options $SPARK_REPL_OPTS" + if $cygwin; then + # Workaround for issue involving JLine and Cygwin + # (see http://sourceforge.net/p/jline/bugs/40/). + # If you're using the Mintty terminal emulator in Cygwin, may need to set the + # "Backspace sends ^H" setting in "Keys" section of the Mintty options + # (see https://github.com/sbt/sbt/issues/562). + stty -icanon min 1 -echo > /dev/null 2>&1 + export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix" + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" + stty icanon echo > /dev/null 2>&1 + else + export SPARK_REPL_OPTS + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" + fi + + +} + +for option in "$@" +do + case $option in + -h | --help ) + usage + exit 1 + ;; + -c | --executor-cores) + shift + _1=$1 + shift + set_cores $_1 "-c/--executor-cores" + ;; + --cores) + shift + _1=$1 + shift + log_warn "The option --cores has been deprecated, please use -c/--executor-cores instead." 
+ set_cores $_1 "--cores" + ;; + -em | --executor-memory) + shift + _1=$1 + shift + set_em $_1 "-em/--executor-memory" + ;; + -dm | --drivermem | --driver-memory) + shift + _1=$1 + shift + set_dm $_1 "-dm/--drivermem/--driver-memory" + ;; +# --drivermem) +# shift +# _1=$1 +# shift +# log_warn "The option --drivermem will soon be deprecated, please use -dm/--driver-memory instead." +# set_dm $_1 "--drivermem" +# ;; + -ld | --local-dir) + shift + _1=$1 + shift + set_localdir $_1 "-ld/--local-dir" + ;; + -dh | --driver-host) + shift + _1=$1 + shift + set_driver_host $_1 "-dh/--driver-host" + ;; + -dp | --driver-port) + shift + _1=$1 + shift + set_driver_port $_1 "-dp/--driver-port" + ;; + -mip | --master-ip) + shift + _1=$1 + shift + set_spark_master_ip $_1 "-mip/--master-ip" + ;; + -mp | --master-port) + shift + _1=$1 + shift + set_spark_master_port $_1 "-mp/--master-port" + ;; + -m | --master) + shift + _1=$1 + shift + set_spark_master $_1 "-m/--master" + ;; + -uip | --ui-port) + shift + _1=$1 + shift + set_uip $_1 "-uip/--ui-port" + ;; + --parallelism) + shift + _1=$1 + shift + set_parallelism $_1 "--parallelism" + ;; + --locality-wait) + shift + _1=$1 + shift + set_locality_wait $_1 "--locality-wait" + ;; + --schedule-fair) + shift + set_spark_scheduler "FAIR" + ;; + --max-failures) + shift + _1=$1 + shift + set_spark_max_failures "$_1" "--max-failures" + ;; + --log-conf) + shift + set_spark_log_conf "true" + ;; + -v | --verbose ) + shift + VERBOSE=1 + ;; + --dry-run) + shift + DRY_RUN=1 + ;; + + ?) + ;; + esac +done # Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in # binary distribution of Spark where Scala is not installed @@ -124,20 +420,10 @@ if [[ ! $? ]]; then saved_stty="" fi -if $cygwin; then - # Workaround for issue involving JLine and Cygwin - # (see http://sourceforge.net/p/jline/bugs/40/). - # If you're using the Mintty terminal emulator in Cygwin, may need to set the - # "Backspace sends ^H" setting in "Keys" section of the Mintty options - # (see https://github.com/sbt/sbt/issues/562). - stty -icanon min 1 -echo > /dev/null 2>&1 - $FWDIR/bin/spark-class -Djline.terminal=unix $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@" - stty icanon echo > /dev/null 2>&1 -else - $FWDIR/bin/spark-class $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@" -fi +main # record the exit status lest it be overwritten: # then reenable echo and propagate the code. exit_status=$? 
onExit + diff --git a/core/pom.xml b/core/pom.xml index 99c841472b3eb..4d7d41a9714d7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -17,258 +17,256 @@ --> - 4.0.0 - - org.apache.spark - spark-parent - 1.0.0-SNAPSHOT - ../pom.xml - - + 4.0.0 + org.apache.spark - spark-core_2.10 - jar - Spark Project Core - http://spark.apache.org/ + spark-parent + 1.0.0-SNAPSHOT + ../pom.xml + - - - - yarn-alpha - - - org.apache.avro - avro - - - - + org.apache.spark + spark-core_2.10 + jar + Spark Project Core + http://spark.apache.org/ - - - org.apache.hadoop - hadoop-client - - - net.java.dev.jets3t - jets3t - - - commons-logging - commons-logging - - - - - org.apache.curator - curator-recipes - - - org.eclipse.jetty - jetty-server - - - com.google.guava - guava - - - com.google.code.findbugs - jsr305 - - - org.slf4j - slf4j-api - - - org.slf4j - jul-to-slf4j - - - org.slf4j - jcl-over-slf4j - - - log4j - log4j - - - org.slf4j - slf4j-log4j12 - - - com.ning - compress-lzf - - - org.xerial.snappy - snappy-java - - - org.ow2.asm - asm - - - com.twitter - chill_${scala.binary.version} - 0.3.1 - - - com.twitter - chill-java - 0.3.1 - - - ${akka.group} - akka-remote_${scala.binary.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} - - - ${akka.group} - akka-testkit_${scala.binary.version} - test - - - org.scala-lang - scala-library - - - org.json4s - json4s-jackson_${scala.binary.version} - 3.2.6 - - - - org.scala-lang - scalap - - - - - it.unimi.dsi - fastutil - - - colt - colt - - - org.apache.mesos - mesos - - - io.netty - netty-all - - - com.clearspring.analytics - stream - - - com.codahale.metrics - metrics-core - - - com.codahale.metrics - metrics-jvm - - - com.codahale.metrics - metrics-json - - - com.codahale.metrics - metrics-ganglia - - - com.codahale.metrics - metrics-graphite - - - org.apache.derby - derby - test - - - commons-io - commons-io - test - - - org.scalatest - scalatest_${scala.binary.version} - test - - - org.mockito - mockito-all - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.easymock - easymock - test - - - com.novocode - junit-interface - test - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-antrun-plugin - - - test - - run - - - true - - - - - - - - - - - - - - - - - - - - org.scalatest - scalatest-maven-plugin - - - ${basedir}/.. 
- 1 - ${spark.classpath} - - - - - + + + org.apache.hadoop + hadoop-client + + + net.java.dev.jets3t + jets3t + + + commons-logging + commons-logging + + + + + org.apache.curator + curator-recipes + + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-security + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-server + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.slf4j + slf4j-api + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + com.ning + compress-lzf + + + org.xerial.snappy + snappy-java + + + com.twitter + chill_${scala.binary.version} + 0.3.1 + + + com.twitter + chill-java + 0.3.1 + + + commons-net + commons-net + + + ${akka.group} + akka-remote_${scala.binary.version} + + + ${akka.group} + akka-slf4j_${scala.binary.version} + + + ${akka.group} + akka-testkit_${scala.binary.version} + test + + + org.scala-lang + scala-library + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.2.6 + + + + org.scala-lang + scalap + + + + + it.unimi.dsi + fastutil + + + colt + colt + + + org.apache.mesos + mesos + + + io.netty + netty-all + + + com.clearspring.analytics + stream + + + com.codahale.metrics + metrics-core + + + com.codahale.metrics + metrics-jvm + + + com.codahale.metrics + metrics-json + + + com.codahale.metrics + metrics-ganglia + + + com.codahale.metrics + metrics-graphite + + + org.apache.derby + derby + test + + + commons-io + commons-io + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymock + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-antrun-plugin + + + test + + run + + + true + + + + + + + + + + + + + + + + + + + + org.scalatest + scalatest-maven-plugin + + + ${basedir}/.. + 1 + ${spark.classpath} + + + + + diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 1daabecf23292..872e892c04fe6 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -71,10 +71,30 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val computedValues = rdd.computeOrReadCheckpoint(split, context) // Persist the result, so long as the task is not running locally if (context.runningLocally) { return computedValues } - val elements = new ArrayBuffer[Any] - elements ++= computedValues - blockManager.put(key, elements, storageLevel, tellMaster = true) - elements.iterator.asInstanceOf[Iterator[T]] + if (storageLevel.useDisk && !storageLevel.useMemory) { + // In the case that this RDD is to be persisted using DISK_ONLY + // the iterator will be passed directly to the blockManager (rather then + // caching it to an ArrayBuffer first), then the resulting block data iterator + // will be passed back to the user. If the iterator generates a lot of data, + // this means that it doesn't all have to be held in memory at one time. + // This could also apply to MEMORY_ONLY_SER storage, but we need to make sure + // blocks aren't dropped by the block store before enabling that. 
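+          // For example (illustrative), rdd.persist(StorageLevel.DISK_ONLY) over an
+          // iterator producing more data than fits in memory can now succeed, since
+          // rows stream straight to disk instead of buffering in an ArrayBuffer first.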
+ blockManager.put(key, computedValues, storageLevel, tellMaster = true) + return blockManager.get(key) match { + case Some(values) => + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) + case None => + logInfo("Failed to store %s".format(key)) + throw new Exception("Block manager failed to return persisted value") + } + } else { + // In this case the RDD is cached to an array buffer. This will save the results + // if we're dealing with a 'one-time' iterator + val elements = new ArrayBuffer[Any] + elements ++= computedValues + blockManager.put(key, elements, storageLevel, tellMaster = true) + return elements.iterator.asInstanceOf[Iterator[T]] + } } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index d3264a4bb3c81..3d7692ea8a49e 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer extends Logging { +private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { var baseDir : File = null var fileDir : File = null @@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging { fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) + httpServer = new HttpServer(baseDir, securityManager) httpServer.start() serverUri = httpServer.uri + logDebug("HTTP file server started at: " + serverUri) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 759e68ee0cc61..cb5df25fa48df 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,15 +19,18 @@ package org.apache.spark import java.io.File +import org.eclipse.jetty.util.security.{Constraint, Password} +import org.eclipse.jetty.security.authentication.DigestAuthenticator +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} + import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import org.eclipse.jetty.server.handler.ResourceHandler +import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils + /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ @@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server.
*/ -private[spark] class HttpServer(resourceBase: File) extends Logging { +private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) + extends Logging { private var server: Server = null private var port: Int = -1 @@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { server.setThreadPool(threadPool) val resHandler = new ResourceHandler resHandler.setResourceBase(resourceBase.getAbsolutePath) + val handlerList = new HandlerList handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) + // make sure we go through security handler to get resources + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("HttpServer is not using security") + server.setHandler(handlerList) + } + server.start() port = server.getConnectors()(0).getLocalPort() } } + /** + * Setup Jetty to the HashLoginService using a single user with our + * shared secret. Configure it to use DIGEST-MD5 authentication so that the password + * isn't passed in plaintext. + */ + private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { + val constraint = new Constraint() + // use DIGEST-MD5 as the authentication mechanism + constraint.setName(Constraint.__DIGEST_AUTH) + constraint.setRoles(Array("user")) + constraint.setAuthenticate(true) + constraint.setDataConstraint(Constraint.DC_NONE) + + val cm = new ConstraintMapping() + cm.setConstraint(constraint) + cm.setPathSpec("/*") + val sh = new ConstraintSecurityHandler() + + // the hashLoginService lets us do a single user and + // secret right now. This could be changed to use the + // JAASLoginService for other options. + val hashLogin = new HashLoginService() + + val userCred = new Password(securityMgr.getSecretKey()) + if (userCred == null) { + throw new Exception("Error: secret key is null with authentication on") + } + hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user")) + sh.setLoginService(hashLogin) + sh.setAuthenticator(new DigestAuthenticator()); + sh.setConstraintMappings(Array(cm)) + sh + } + def stop() { if (server == null) { throw new ServerStateException("Server is already stopped") diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index b749e5414dab6..7423082e34f47 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.apache.log4j.{LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.impl.StaticLoggerBinder /** * Utility trait for classes that want to log data. 
Creates a SLF4J logger for the class and allows @@ -101,9 +102,11 @@ trait Logging { } private def initializeLogging() { - // If Log4j doesn't seem initialized, load a default properties file + // If Log4j is being used, but is not initialized, load a default properties file + val binder = StaticLoggerBinder.getSingleton + val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized) { + if (!log4jInitialized && usingLog4j) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" val classLoader = this.getClass.getClassLoader Option(classLoader.getResource(defaultLogProps)) match { diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala new file mode 100644 index 0000000000000..591978c1d3630 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.net.{Authenticator, PasswordAuthentication} +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import scala.collection.mutable.ArrayBuffer + +/** + * Spark class responsible for security. + * + * In general this class should be instantiated by the SparkEnv and most components + * should access it from that. There are some cases where the SparkEnv hasn't been + * initialized yet and this class must be instantiated directly. + * + * Spark currently supports authentication via a shared secret. + * Authentication can be configured to be on via the 'spark.authenticate' configuration + * parameter. This parameter controls whether the Spark communication protocols do + * authentication using the shared secret. This authentication is a basic handshake to + * make sure both sides have the same shared secret and are allowed to communicate. + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * control the behavior of the acls. Note that the person who started the application + * always has view access to the UI. 
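+ *
+ * For example (illustrative values): starting an application with
+ *   spark.authenticate=true, spark.authenticate.secret=mySecret,
+ *   spark.ui.acls.enable=true, spark.ui.view.acls=alice,bob
+ * enables the shared-secret handshake and limits UI access to alice, bob, and
+ * the user who started the application.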
+ * + * Spark does not currently support encryption after authentication. + * + * At this point Spark has multiple communication protocols that need to be secured and + * different underlying mechanisms are used depending on the protocol: + * + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * permission to connect to the server. There is no control of the underlying + * authentication mechanism so it's not clear if the password is passed in + * plaintext or uses DIGEST-MD5 or some other mechanism. + * Akka also has an option to turn on SSL; this option is not currently supported, + * but we could add a configuration option in the future. + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, SPNEGO, etc. It also supports multiple different login + * services - Hash, JAAS, SPNEGO, JDBC, etc. Spark currently uses the HashLoginService + * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * Since we are using DIGEST-MD5, the shared secret is not passed on the wire + * in plaintext. + * We currently do not support SSL (https), but Jetty can be configured to use it, + * so we could add a configuration option for this in the future. + * + * The Spark HttpServer installs the HashLoginService and configures it to use DIGEST-MD5. + * Any clients must specify the user and password. There is a default + * Authenticator installed in the SecurityManager that defines how the authentication + * is done, in this case by getting the user name and password from the request. + * + * - ConnectionManager -> The Spark ConnectionManager uses Java NIO to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * as the authentication mechanism. This means the shared secret is not passed + * over the wire in plaintext. + * Note that SASL is pluggable as to what mechanism it uses. We currently use + * DIGEST-MD5, but this could be changed to use Kerberos or another mechanism in the future. + * Spark currently supports "auth" for the quality of protection, which means + * the connection does not support integrity or privacy protection (encryption) + * after authentication. SASL also supports "auth-int" and "auth-conf", which + * Spark could support in the future to allow the user to specify the quality + * of protection they want. If we support those, the messages will also have to + * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap APIs. + * + * Since the ConnectionManager does asynchronous message passing, the SASL + * authentication is a bit more complex. A ConnectionManager can be both a client + * and a server, so for a particular connection it has to determine what to do. + * A ConnectionId was added to be able to track connections and is used to + * match up incoming messages with connections waiting for authentication. + * If it's acting as a client and trying to send a message to another ConnectionManager, + * it blocks the thread calling sendMessage until the SASL negotiation has occurred. + * The ConnectionManager tracks all the sendingConnections using the ConnectionId + * and waits for the response from the server and does the handshake. + * + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * can be used. Yarn requires that a specific AmIpFilter be installed for security to work + * properly. For non-Yarn deployments, users can write a filter to go through a + * company's normal login service. If an authentication filter is in place then the + * SparkUI can be configured to check the logged-in user against the list of users who + * have view acls to see if that user is authorized. + * The filters can also be used for many different purposes. For instance, filters + * could be used for logging, encryption, or compression. + * + * The exact mechanisms used to generate/distribute the shared secret are deployment specific. + * + * For Yarn deployments, the secret is automatically generated using the Akka remote + * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed + * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels + * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn + * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn + * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there + * to reduce the possibility of web-based attacks through YARN. Hadoop can be configured to use + * filters to do authentication. That authentication then happens via the ResourceManager Proxy + * and Spark will use that to do authorization against the view acls. + * + * For other Spark deployments, the shared secret must be specified via the + * spark.authenticate.secret config. + * All the nodes (Master and Workers) and the applications need to have the same shared secret. + * This again is not ideal as one user could potentially affect another user's application. + * This should be enhanced in the future to provide better protection. + * If the UI needs to be secured, the user needs to install a javax servlet filter to do the + * authentication. Spark will then use that user to compare against the view acls to do + * authorization. If no filter is in place the user is generally null and no authorization + * can take place. + */ + +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { + + // key used to store the spark secret in the Hadoop UGI + private val sparkSecretLookupKey = "sparkCookie" + + private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + + // always add the current user and SPARK_USER to the viewAcls + private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""), + Option(System.getenv("SPARK_USER")).getOrElse("")) + aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',') + private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet + + private val secretKey = generateSecretKey() + logInfo("SecurityManager, is authentication enabled: " + authOn + + " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) + + // Set our own authenticator to properly negotiate user/password for HTTP connections. + // This is needed by the HTTP client fetching from the HttpServer. Put here so it's + // only set once.
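+ // Illustrative flow: when a client fetches http://user:secret@host:port/jars/x.jar,
+ // the JDK invokes this Authenticator and the credentials are recovered below from
+ // the URL's userInfo ("user:secret") section.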
+ if (authOn) { + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ) + } + + /** + * Generates or looks up the secret key. + * + * The way the key is stored depends on the Spark deployment mode. Yarn + * uses the Hadoop UGI. + * + * For non-Yarn deployments, if the config variable is not set, + * we throw an exception. + */ + private def generateSecretKey(): String = { + if (!isAuthenticationEnabled) return null + // first check to see if the secret is already set, else generate a new one if on yarn + val sCookie = if (SparkHadoopUtil.get.isYarnMode) { + val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) + if (secretKey != null) { + logDebug("in yarn mode, getting secret from credentials") + return new Text(secretKey).toString + } else { + logDebug("getSecretKey: yarn mode, secret key from credentials is null") + } + val cookie = akka.util.Crypt.generateSecureCookie + // if we generated the secret then we must be the first, so set it so that it + // gets used by everyone else + SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) + logInfo("adding secret to credentials in yarn mode") + cookie + } else { + // user must have set the spark.authenticate.secret config + sparkConf.getOption("spark.authenticate.secret") match { + case Some(value) => value + case None => throw new Exception("Error: a secret key must be specified via the " + + "spark.authenticate.secret config") + } + } + sCookie + } + + /** + * Check to see if acls for the UI are enabled. + * @return true if UI authentication is enabled, otherwise false + */ + def uiAclsEnabled(): Boolean = uiAclsOn + + /** + * Checks the given user against the view acl list to see if they have + * authorization to view the UI. If UI acls are disabled + * via spark.ui.acls.enable, all users have view access. + * + * @param user the user to check for authorization + * @return true if the user has permission, otherwise false + */ + def checkUIViewPermissions(user: String): Boolean = { + if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + } + + /** + * Check to see if authentication for the Spark communication protocols is enabled. + * @return true if authentication is enabled, otherwise false + */ + def isAuthenticationEnabled(): Boolean = authOn + + /** + * Gets the user used for authenticating HTTP connections. + * For now use a single hardcoded user. + * @return the HTTP user as a String + */ + def getHttpUser(): String = "sparkHttpUser" + + /** + * Gets the user used for authenticating SASL connections. + * For now use a single hardcoded user. + * @return the SASL user as a String + */ + def getSaslUser(): String = "sparkSaslUser" + + /** + * Gets the secret key.
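+ * (Generated once at construction time; both the HTTP and SASL authentication
+ * paths reuse the same value.)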
+ * @return the secret key as a String if authentication is enabled, otherwise returns null + */ + def getSecretKey(): String = secretKey +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index da778aa851cd2..cdc0e5a34240e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -130,6 +130,8 @@ class SparkContext( val isLocal = (master == "local" || master.startsWith("local[")) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.create( conf, @@ -160,19 +162,20 @@ class SparkContext( jars.foreach(addJar) } + def warnSparkMem(value: String): String = { + logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + + "deprecated, please use spark.executor.memory instead.") + value + } + private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_MEM"))) + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) .map(Utils.memoryStringToMb) .getOrElse(512) - if (!conf.contains("spark.executor.memory") && sys.env.contains("SPARK_MEM")) { - logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + - "deprecated, instead use spark.executor.memory") - } - // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS"); value <- Option(System.getenv(key))) { executorEnvs(key) = value @@ -183,8 +186,9 @@ class SparkContext( value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = executorMemory + "m" + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. @@ -634,7 +638,7 @@ class SparkContext( addedFiles(key) = System.currentTimeMillis // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } @@ -736,8 +740,10 @@ class SparkContext( key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") { - // In order for this to work in yarn standalone mode the user must specify the + // yarn-standalone is deprecated, but still supported + if (SparkHadoopUtil.get.isYarnMode() && + (master == "yarn-standalone" || master == "yarn-cluster")) { + // In order for this to work in yarn-cluster mode the user must specify the // --addjars option to the client to upload the file into the distributed cache // of the AM to make it show up in the current working directory. 
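+          // (Illustrative note: YARN localizes a jar passed via --addjars into the
+          // container's working directory, which is why the bare file name below is
+          // enough to resolve it.)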
val fileName = new Path(uri.getPath).getName() @@ -1025,7 +1031,7 @@ class SparkContext( * The SparkContext object contains a number of implicit conversions and parameters for use with * various Spark features. */ -object SparkContext { +object SparkContext extends Logging { private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" @@ -1243,7 +1249,11 @@ object SparkContext { } scheduler - case "yarn-standalone" => + case "yarn-standalone" | "yarn-cluster" => + if (master == "yarn-standalone") { + logWarning( + "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") + } val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 7ac65828f670f..5e43b5198422c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -53,7 +53,8 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) extends Logging { + val conf: SparkConf, + val securityManager: SecurityManager) extends Logging { // A mapping of thread ID to amount of memory used for shuffle in bytes // All accesses should be manually synchronized @@ -122,8 +123,9 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, - conf = conf) + val securityManager = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, + securityManager = securityManager) // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), // figure out which port number Akka actually bound to and set spark.driver.port to it. 
@@ -139,7 +141,6 @@ object SparkEnv extends Logging { val name = conf.get(propertyName, defaultClassName) Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( @@ -167,12 +168,12 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf)), conf) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + serializer, conf, securityManager) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isDriver, conf) + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val cacheManager = new CacheManager(blockManager) @@ -190,14 +191,14 @@ object SparkEnv extends Logging { val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer() + val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver", conf) + MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf) + MetricsSystem.createMetricsSystem("executor", conf, securityManager) } metricsSystem.start() @@ -231,6 +232,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, - conf) + conf, + securityManager) } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala new file mode 100644 index 0000000000000..a2a871cbd3c31 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.IOException +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.RealmCallback +import javax.security.sasl.RealmChoiceCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslClient +import javax.security.sasl.SaslException + +import scala.collection.JavaConversions.mapAsJavaMap + +/** + * Implements SASL Client logic for Spark + */ +private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { + + /** + * Used to respond to server's counterpart, SaslServer with SASL tokens + * represented as byte arrays. + * + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslClientCallbackHandler(securityMgr)) + + /** + * Used to initiate SASL handshake with server. + * @return response to challenge if needed + */ + def firstToken(): Array[Byte] = { + synchronized { + val saslToken: Array[Byte] = + if (saslClient != null && saslClient.hasInitialResponse()) { + logDebug("has initial response") + saslClient.evaluateChallenge(new Array[Byte](0)) + } else { + new Array[Byte](0) + } + saslToken + } + } + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslClient != null) saslClient.isComplete() else false + } + } + + /** + * Respond to server's SASL token. + * @param saslTokenMessage contains server's SASL token + * @return client's response SASL token + */ + def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { + synchronized { + if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + def dispose() { + synchronized { + if (saslClient != null) { + try { + saslClient.dispose() + } catch { + case e: SaslException => // ignored + } finally { + saslClient = null + } + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends + CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + private val secretKey = securityMgr.getSecretKey() + private val userPassword: Array[Char] = + SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes()) + + /** + * Implementation used to respond to SASL request from the server. + * + * @param callbacks objects that indicate what credential information the + * server's SaslServer requires from the client. 
+ */ + override def handle(callbacks: Array[Callback]) { + logDebug("in the sasl client callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL client callback: setting username: " + userName) + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL client callback: setting userPassword") + pc.setPassword(userPassword) + } + case rc: RealmCallback => { + logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case cb: RealmChoiceCallback => {} + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala new file mode 100644 index 0000000000000..11fcb2ae3a5c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.AuthorizeCallback +import javax.security.sasl.RealmCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslException +import javax.security.sasl.SaslServer +import scala.collection.JavaConversions.mapAsJavaMap +import org.apache.commons.net.util.Base64 + +/** + * Encapsulates SASL server logic + */ +private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { + + /** + * Actual SASL work done by this object from javax.security.sasl. + */ + private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, + SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslDigestCallbackHandler(securityMgr)) + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslServer != null) saslServer.isComplete() else false + } + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + def response(token: Array[Byte]): Array[Byte] = { + synchronized { + if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. 
+   */
+  def dispose() {
+    synchronized {
+      if (saslServer != null) {
+        try {
+          saslServer.dispose()
+        } catch {
+          case e: SaslException => // ignored
+        } finally {
+          saslServer = null
+        }
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler
+   * for the SASL DIGEST-MD5 mechanism
+   */
+  private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+    extends CallbackHandler {
+
+    private val userName: String =
+      SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+
+    override def handle(callbacks: Array[Callback]) {
+      logDebug("In the sasl server callback handler")
+      callbacks foreach {
+        case nc: NameCallback => {
+          logDebug("handle: SASL server callback: setting username")
+          nc.setName(userName)
+        }
+        case pc: PasswordCallback => {
+          logDebug("handle: SASL server callback: setting userPassword")
+          val password: Array[Char] =
+            SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes())
+          pc.setPassword(password)
+        }
+        case rc: RealmCallback => {
+          logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+          rc.setText(rc.getDefaultText())
+        }
+        case ac: AuthorizeCallback => {
+          val authid = ac.getAuthenticationID()
+          val authzid = ac.getAuthorizationID()
+          if (authid.equals(authzid)) {
+            logDebug("set auth to true")
+            ac.setAuthorized(true)
+          } else {
+            logDebug("set auth to false")
+            ac.setAuthorized(false)
+          }
+          if (ac.isAuthorized()) {
+            logDebug("sasl server is authorized")
+            ac.setAuthorizedID(authzid)
+          }
+        }
+        case cb: Callback => throw
+          new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+      }
+    }
+  }
+}
+
+private[spark] object SparkSaslServer {
+
+  /**
+   * This is passed as the server name when creating the sasl client/server.
+   * This could be changed to be configurable in the future.
+   */
+  val SASL_DEFAULT_REALM = "default"
+
+  /**
+   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+   * configurable in the future.
+   */
+  val DIGEST = "DIGEST-MD5"
+
+  /**
+   * The quality of protection is just "auth". This means that we are doing
+   * authentication only; we are not supporting integrity or privacy protection of the
+   * communication channel after authentication. This could be changed to be configurable
+   * in the future.
+   */
+  val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH -> "true")
+
+  /**
+   * Encode a byte[] identifier as a Base64-encoded string.
+   *
+   * @param identifier identifier to encode
+   * @return Base64-encoded string
+   */
+  def encodeIdentifier(identifier: Array[Byte]): String = {
+    new String(Base64.encodeBase64(identifier))
+  }
+
+  /**
+   * Encode a password as a Base64-encoded char array.
+   * @param password the password as a byte array.
+   * @return password as a char array.
+   */
+  def encodePassword(password: Array[Byte]): Array[Char] = {
+    new String(Base64.encodeBase64(password)).toCharArray()
+  }
+}
+
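Reviewer note: to make the intended interplay between these two new classes concrete, here is a minimal sketch that drives a full DIGEST-MD5 exchange in-process. It assumes spark.authenticate is on and that both SecurityManagers resolve the same shared secret; in the actual patch the tokens travel inside SecurityMessages via the ConnectionManager, not a local loop.

```scala
// Reviewer sketch only, not part of the patch. All classes are in
// org.apache.spark (and private[spark], so this only compiles from there).
import org.apache.spark.{SecurityManager, SparkConf, SparkSaslClient, SparkSaslServer}

val conf = new SparkConf().set("spark.authenticate", "true")
val securityMgr = new SecurityManager(conf) // assumes a shared secret is configured
val client = new SparkSaslClient(securityMgr)
val server = new SparkSaslServer(securityMgr)

// DIGEST-MD5 has no initial client response, so this first token is empty.
var token = client.firstToken()
while (!client.isComplete() || !server.isComplete()) {
  token = server.response(token)        // server evaluates the client's token
  if (!client.isComplete()) {
    token = client.saslResponse(token)  // client answers the server's challenge
  }
}
client.dispose()
server.dispose()
```

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index d1787061bc642..f816bb43a5b44 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
    */
   def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
 
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate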
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd))
+
   // Double RDD functions
 
   /** Add up the elements in this RDD. */
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 857626fe84af9..0ff428c120353 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.union(other.rdd))
 
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.intersection(other.rdd))
+
+
   // first() has to be overridden here so that the generated method has the signature
   // 'public scala.Tuple2 first()'; if the trait's definition is used,
   // then the method has the signature 'public java.lang.Object first()',
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index e973c46edd1ce..91bf404631f49 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
    */
  def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
 
+
+  /**
+   * Return the intersection of this RDD and another one. The output will not contain any duplicate
+   * elements, even if the input RDDs did.
+   *
+   * Note that this method performs a shuffle internally.
+   */
+  def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd))
+
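Reviewer note: a quick usage sketch of the new Java-facing method, called from Scala here; the data and app name are arbitrary. Unlike union(), intersection() shuffles and deduplicates its output.

```scala
// Reviewer sketch only: exercising the new JavaRDD.intersection().
import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext("local", "intersection-example")
val left = jsc.parallelize(Arrays.asList(1, 2, 2, 3, 4))
val right = jsc.parallelize(Arrays.asList(2, 3, 3, 5))
// Contains 2 and 3 exactly once each (ordering not guaranteed).
println(left.intersection(right).collect())
jsc.stop()
```

   /**
    * Return an RDD with the elements from `this` that are not in `other`.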
    *
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d4040594d..e3c3a12d16f2a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
 }
 
 private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
+  extends Logging with Serializable {
 
   private var initialized = false
   private var broadcastFactory: BroadcastFactory = null
@@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging
         Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
 
       // Initialize appropriate BroadcastFactory and BroadcastObject
-      broadcastFactory.initialize(isDriver, conf)
+      broadcastFactory.initialize(isDriver, conf, securityManager)
 
       initialized = true
     }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 940e5ab805100..6beecaeced5be 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -16,6 +16,7 @@
  */
 
 package org.apache.spark.broadcast
 
+import org.apache.spark.SecurityManager
 import org.apache.spark.SparkConf
 
@@ -26,7 +27,7 @@ import org.apache.spark.SparkConf
  * entire Spark job.
  */
 trait BroadcastFactory {
-  def initialize(isDriver: Boolean, conf: SparkConf): Unit
+  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
   def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
   def stop(): Unit
 }
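Reviewer note: this is a source-incompatible change to the BroadcastFactory trait, so any external factory implementation must add the new parameter. A hypothetical skeleton against the new signature (the class name and no-op behavior are made up for illustration):

```scala
// Reviewer sketch only: a minimal factory compiled against the new trait.
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.broadcast.{Broadcast, BroadcastFactory}

class NoopBroadcastFactory extends BroadcastFactory {
  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
    // A factory that serves blocks over the network should consult
    // securityMgr.isAuthenticationEnabled() here, as HttpBroadcast now does.
  }
  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] =
    throw new UnsupportedOperationException("illustration only")
  def stop() {}
}
```

diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 20207c261320b..e8eb04bb10469 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -18,13 +18,13 @@ package org.apache.spark.broadcast
 
 import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.net.{URL, URLConnection, URI}
 import java.util.concurrent.TimeUnit
 
 import it.unimi.dsi.fastutil.io.FastBufferedInputStream
 import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
@@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
  * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.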
*/ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging { private var bufferSize: Int = 65536 private var serverUri: String = null private var server: HttpServer = null + private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] @@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging { private var compressionCodec: CompressionCodec = null - def initialize(isDriver: Boolean, conf: SparkConf) { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { if (!initialized) { bufferSize = conf.getInt("spark.buffer.size", 65536) compress = conf.getBoolean("spark.broadcast.compress", true) + securityManager = securityMgr if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) @@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir) + server = new HttpServer(broadcastDir, securityManager) server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) @@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name + + var uc: URLConnection = null + if (securityManager.isAuthenticationEnabled()) { + logDebug("broadcast security enabled") + val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("broadcast not using security") + uc = new URL(url).openConnection() + } + val in = { - val httpConnection = new URL(url).openConnection() - httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream + uc.setReadTimeout(httpReadTimeout) + val inputStream = uc.getInputStream(); if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d783c8590c6..3cd71213769b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -241,7 +241,9 @@ private[spark] case class TorrentInfo( */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 
eb5676b51d836..d9e3035e1ab59 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
 import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
 import org.apache.log4j.{Level, Logger}
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.master.{DriverState, Master}
 import org.apache.spark.util.{AkkaUtils, Utils}
@@ -141,7 +141,7 @@ object Client {
     // TODO: See if we can initialize akka so return messages are sent back using the same TCP
     //       flow. Else, this (sadly) requires the DriverClient be routable from the Master.
     val (actorSystem, _) = AkkaUtils.createActorSystem(
-      "driverClient", Utils.localHostName(), 0, false, conf)
+      "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
 
     actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index d48c1892aea9c..f4eb1601be3e4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -30,20 +30,24 @@ import scala.sys.process._
 import org.json4s._
 import org.json4s.jackson.JsonMethods
 
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.deploy.master.RecoveryState
+import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil}
 
 /**
  * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
  * In order to mimic a real distributed cluster more closely, Docker is used.
  * Execute using
- * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest
 *
- * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS
+ * *and* SPARK_JAVA_OPTS:
  *   - spark.deploy.recoveryMode=ZOOKEEPER
  *   - spark.deploy.zookeeper.url=172.17.42.1:2181
  * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
  *
+ * In case of failure, make sure to kill off prior docker containers before restarting:
+ *   docker kill $(docker ps -q)
+ *
  * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
  * working installation of Docker. In addition to having Docker, the following are assumed:
  *   - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
@@ -51,10 +55,16 @@ import org.apache.spark.deploy.master.RecoveryState
  *   docker/ directory. Run 'docker/spark-test/build' to generate these.
*/ private[spark] object FaultToleranceTest extends App with Logging { + + val conf = new SparkConf() + val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + val masters = ListBuffer[TestMasterInfo]() val workers = ListBuffer[TestWorkerInfo]() var sc: SparkContext = _ + val zk = SparkCuratorUtil.newClient(conf) + var numPassed = 0 var numFailed = 0 @@ -72,6 +82,10 @@ private[spark] object FaultToleranceTest extends App with Logging { sc = null } terminateCluster() + + // Clear ZK directories in between tests (for speed purposes) + SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/spark_leader") + SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/master_status") } test("sanity-basic") { @@ -168,26 +182,34 @@ private[spark] object FaultToleranceTest extends App with Logging { try { fn numPassed += 1 + logInfo("==============================================") logInfo("Passed: " + name) + logInfo("==============================================") } catch { case e: Exception => numFailed += 1 + logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logError("FAILED: " + name, e) + logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + sys.exit(1) } afterEach() } def addMasters(num: Int) { + logInfo(s">>>>> ADD MASTERS $num <<<<<") (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) } } def addWorkers(num: Int) { + logInfo(s">>>>> ADD WORKERS $num <<<<<") val masterUrls = getMasterUrls(masters) (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } } /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ def createClient() = { + logInfo(">>>>> CREATE CLIENT <<<<<") if (sc != null) { sc.stop() } // Counter-hack: Because of a hack in SparkEnv#create() that changes this // property, we need to reset it. @@ -206,6 +228,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } def killLeader(): Unit = { + logInfo(">>>>> KILL LEADER <<<<<") masters.foreach(_.readState()) val leader = getLeader masters -= leader @@ -215,6 +238,7 @@ private[spark] object FaultToleranceTest extends App with Logging { def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) def terminateCluster() { + logInfo(">>>>> TERMINATE CLUSTER <<<<<") masters.foreach(_.kill()) workers.foreach(_.kill()) masters.clear() @@ -245,6 +269,7 @@ private[spark] object FaultToleranceTest extends App with Logging { * are all alive in a proper configuration (e.g., only one leader). */ def assertValidClusterState() = { + logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<") assertUsable() var numAlive = 0 var numStandby = 0 @@ -326,7 +351,11 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val val workers = json \ "workers" val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE") - liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String]) + // Extract the worker IP from "webuiaddress" (rather than "host") because the host name + // on containers is a weird hash instead of the actual IP address. 
+    liveWorkerIPs = liveWorkers.map {
+      w => (w \ "webuiaddress").extract[String].stripPrefix("http://").stripSuffix(":8081")
+    }
 
     numLiveApps = (json \ "activeapps").children.size
 
@@ -403,7 +432,7 @@ private[spark] object Docker extends Logging {
 
   def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
     val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
 
-    val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+    val cmd = "docker run -privileged %s %s %s".format(mountCmd, imageTag, args)
     logDebug("Run command: " + cmd)
     cmd
   }
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index ec15647e1d9eb..d2d8d6d662d55 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
 import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.{SparkContext, SparkException}
@@ -65,6 +66,15 @@ class SparkHadoopUtil {
   def addCredentials(conf: JobConf) {}
 
   def isYarnMode(): Boolean = { false }
+
+  def getCurrentUserCredentials(): Credentials = { null }
+
+  def addCurrentUserCredentials(creds: Credentials) {}
+
+  def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+  def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
 }
 
 object SparkHadoopUtil {
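Reviewer note: these new credential helpers are deliberately no-ops in the base class; the intent is for a deployment mode with real user credentials (e.g. YARN) to back them with Hadoop's Credentials API. A hypothetical sketch of such an override (the class name is made up; only stock Hadoop calls are used):

```scala
// Reviewer sketch only: how a subclass could implement the new hooks.
import org.apache.hadoop.io.Text
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.deploy.SparkHadoopUtil

class SecretAwareHadoopUtil extends SparkHadoopUtil {
  override def getCurrentUserCredentials(): Credentials =
    UserGroupInformation.getCurrentUser().getCredentials()

  override def addSecretKeyToUserCredentials(key: String, secret: String) {
    val creds = new Credentials()
    creds.addSecretKey(new Text(key), secret.getBytes("utf-8"))
    UserGroupInformation.getCurrentUser().addCredentials(creds)
  }

  override def getSecretKeyFromUserCredentials(key: String): Array[Byte] =
    getCurrentUserCredentials().getSecretKey(new Text(key))
}
```

diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 1550c3eb4286b..63f166d401059 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.deploy.client
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
 import org.apache.spark.deploy.{ApplicationDescription, Command}
 import org.apache.spark.util.{AkkaUtils, Utils}
 
@@ -45,8 +45,9 @@ private[spark] object TestClient {
 
   def main(args: Array[String]) {
     val url = args(0)
+    val conf = new SparkConf
     val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
-      conf = new SparkConf)
+      conf = conf, securityManager = new SecurityManager(conf))
     val desc = new ApplicationDescription(
       "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
       Some("dummy-spark-home"), "ignored")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51794ce40cb45..b8dfa44102583 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 import akka.serialization.SerializationExtension
 
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.master.DriverState.DriverState
@@ -39,7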
+39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() val conf = new SparkConf @@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf) + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + securityMgr) val masterSource = new MasterSource(this) val webUi = new MasterWebUI(this, webUiPort) @@ -529,8 +531,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val workerAddress = worker.actor.path.address if (addressToWorker.contains(workerAddress)) { - logInfo("Attempted to re-register worker at same address: " + workerAddress) - return false + val oldWorker = addressToWorker(workerAddress) + if (oldWorker.state == WorkerState.UNKNOWN) { + // A worker registering from UNKNOWN implies that the worker was restarted during recovery. + // The old worker must thus be dead, so we will remove it and accept the new worker. + removeWorker(oldWorker) + } else { + logInfo("Attempted to re-register worker at same address: " + workerAddress) + return false + } } workers += worker @@ -711,8 +720,11 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) : (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) + val securityMgr = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, + securityManager = securityMgr) + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, + securityMgr), actorName) val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala index 2d35397035a03..4781a80d470e1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala @@ -17,11 +17,13 @@ package org.apache.spark.deploy.master -import org.apache.spark.{SparkConf, Logging} +import scala.collection.JavaConversions._ + import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry import org.apache.zookeeper.KeeperException +import org.apache.spark.{Logging, SparkConf} object SparkCuratorUtil extends Logging { @@ -50,4 +52,13 @@ object SparkCuratorUtil extends Logging { } } } + + def deleteRecursive(zk: CuratorFramework, path: String) { + if 
(zk.checkExists().forPath(path) != null) { + for (child <- zk.getChildren.forPath(path)) { + zk.delete().forPath(path + "/" + child) + } + zk.delete().forPath(path) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 5ab13e7aa6b1f..a7bd01e284c8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.master.Master @@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) @@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ master.applicationMetricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"), + createServletHandler("/app/json", + createServlet((request: HttpServletRequest) => applicationPage.renderJson(request), + master.securityMgr)), + createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage + .render(request), master.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), master.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), master.securityMgr)) ) def stop() { @@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { } private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index a26e47950a0ec..be15138f62406 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import akka.actor._ -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -29,8 +29,9 @@ object DriverWrapper { def main(args: Array[String]) { args.toList match { case workerUrl :: mainClass :: extraArgs => + 
val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", - Utils.localHostName(), 0, false, new SparkConf()) + Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") // Delegate to supplied main class diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7b0b7861b76e1..afaabedffefea 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} @@ -48,7 +48,8 @@ private[spark] class Worker( actorSystemName: String, actorName: String, workDirPath: String = null, - val conf: SparkConf) + val conf: SparkConf, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher @@ -91,7 +92,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) def coresFree: Int = cores - coresUsed @@ -347,10 +348,11 @@ private[spark] object Worker { val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" + val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf) + conf = conf, securityManager = securityMgr) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf), name = actorName) + masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index bdf126f93abc8..ffc05bd30687a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker @@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { + extends Logging { val timeout = AkkaUtils.askTimeout(worker.conf) val host = Utils.localHostName() val port = requestedPort.getOrElse( @@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val metricsHandlers = worker.metricsSystem.getServletHandlers - val handlers = 
metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"), + createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request), + worker.securityMgr)), + createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage + (request), worker.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), worker.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), worker.securityMgr)) ) def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) @@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I } private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_BASE = "org/apache/spark/ui" val DEFAULT_PORT="8081" } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 0aae569b17272..3486092a140fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend { // Debug code Utils.checkHost(hostname) + val conf = new SparkConf // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = new SparkConf) + indestructible = true, conf = conf, new SecurityManager(conf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 989d666f15600..e69f6f72d3275 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -69,11 +69,6 @@ private[spark] class Executor( conf.set("spark.local.dir", getYarnLocalDirs()) } - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - 
Thread.currentThread.setContextClassLoader(replClassLoader)
-
   if (!isLocal) {
     // Setup an uncaught exception handler for non-local mode.
     // Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,6 +112,12 @@ private[spark] class Executor(
     }
   }
 
+  // Create our ClassLoader and set it on this thread.
+  // Do this after SparkEnv creation so it can access the SecurityManager.
+  private val urlClassLoader = createClassLoader()
+  private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+  Thread.currentThread.setContextClassLoader(replClassLoader)
+
   // Akka's message frame size. If task result is bigger than this, we use the block manager
   // to send the result back.
   private val akkaFrameSize = {
@@ -338,12 +339,12 @@ private[spark] class Executor(
       // Fetch missing dependencies
       for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
         currentFiles(name) = timestamp
       }
       for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
         logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
         currentJars(name) = timestamp
         // Add it to our class loader
         val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 966c092124266..c5bda2078fc14 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 
 import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
 import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
 import org.apache.spark.metrics.source.Source
 
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
  * [options] is the specific property of this source or sink.
*/ private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf) extends Logging { + conf: SparkConf, securityMgr: SecurityManager) extends Logging { val confFile = conf.get("spark.metrics.conf", null) val metricsConfig = new MetricsConfig(Option(confFile)) @@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String, val classPath = kv._2.getProperty("class") try { val sink = Class.forName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry]) - .newInstance(kv._2, registry) + .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) } else { @@ -160,6 +160,7 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = - new MetricsSystem(instance, conf) + def createMetricsSystem(instance: String, conf: SparkConf, + securityMgr: SecurityManager): MetricsSystem = + new MetricsSystem(instance, conf, securityMgr) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 98fa1dbd7c6ab..4d2ffc54d8983 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class ConsoleSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 40f64768e6885..319f40815d65f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{CsvReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class CsvSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" val CSV_KEY_DIR = "directory" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 410ca0704b5c4..cd37317da77de 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.ganglia.GangliaReporter import info.ganglia.gmetric4j.gmetric.GMetric +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GangliaSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { 
   val GANGLIA_KEY_PERIOD = "period"
   val GANGLIA_DEFAULT_PERIOD = 10
 
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index e09be001421fc..0ffdf3846dc4a 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
 import com.codahale.metrics.MetricRegistry
 import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
 
+import org.apache.spark.SecurityManager
 import org.apache.spark.metrics.MetricsSystem
 
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+    securityMgr: SecurityManager) extends Sink {
   val GRAPHITE_DEFAULT_PERIOD = 10
   val GRAPHITE_DEFAULT_UNIT = "SECONDS"
   val GRAPHITE_DEFAULT_PREFIX = ""
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index b5cf210af2119..3b5edd5c376f0 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
 import java.util.Properties
 
 import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+class JmxSink(val property: Properties, val registry: MetricRegistry,
+    securityMgr: SecurityManager) extends Sink {
 
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
   val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
 
   override def start() {
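Reviewer note: since MetricsSystem now reflectively looks up a (Properties, MetricRegistry, SecurityManager) constructor, any user-defined sink has to add the extra parameter even if it never uses it. A hypothetical sink illustrating the new contract (name and behavior are made up):

```scala
// Reviewer sketch only: the three-argument constructor below is what
// registerSinks() now resolves via reflection.
package org.apache.spark.metrics.sink

import java.util.Properties

import com.codahale.metrics.MetricRegistry

import org.apache.spark.SecurityManager

class StdoutSink(val property: Properties, val registry: MetricRegistry,
    securityMgr: SecurityManager) extends Sink {
  override def start() {
    println("StdoutSink started with " + registry.getNames.size + " metrics registered")
  }
  override def stop() { }
}
```

Such a sink would presumably then be wired up through the metrics properties file in the usual instance.sink.name.class form, e.g. *.sink.stdout.class=org.apache.spark.metrics.sink.StdoutSink.

diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 3cdfe26d40f66..3110eccdee4fc 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
 
 import java.util.Properties
 import java.util.concurrent.TimeUnit
+
 import javax.servlet.http.HttpServletRequest
 
 import com.codahale.metrics.MetricRegistry
 import com.codahale.metrics.json.MetricsModule
 import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
 
+import org.apache.spark.SecurityManager
 import org.apache.spark.ui.JettyUtils
 
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+    securityMgr: SecurityManager) extends Sink {
   val SERVLET_KEY_PATH = "path"
   val SERVLET_KEY_SAMPLE = "sample"
 
@@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
   val mapper = new ObjectMapper().registerModule(
     new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
 
-  def getHandlers = Array[(String, Handler)](
-    (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+  def getHandlers = Array[ServletContextHandler](
+    JettyUtils.createServletHandler(servletPath,
+      JettyUtils.createServlet(
+        new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"),
+        securityMgr) )
   )
 
   def getMetricsSnapshot(request: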
HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index d3c09b16063d6..04df2f3b0d696 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } + val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Attempting to get chunk from message with multiple data buffers") } val buffer = buffers(0) + val security = if (isSecurityNeg) 1 else 0 if (buffer.remaining > 0) { if (buffer.remaining < chunkSize) { throw new Exception("Not enough space in data buffer for receiving chunk") @@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 8219a185ea983..8fd9c2b87d256 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -17,6 +17,11 @@ package org.apache.spark.network +import org.apache.spark._ +import org.apache.spark.SparkSaslServer + +import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} + import java.net._ import java.nio._ import java.nio.channels._ @@ -27,13 +32,16 @@ import org.apache.spark._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { - def this(channel_ : SocketChannel, selector_ : Selector) = { + var sparkSaslServer: SparkSaslServer = null + var sparkSaslClient: SparkSaslClient = null + + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -49,6 +57,16 @@ abstract class Connection(val channel: 
SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() + /** + * Used to synchronize client requests: client's work-related requests must + * wait until SASL authentication completes. + */ + private val authenticated = new Object() + + def getAuthenticated(): Object = authenticated + + def isSaslComplete(): Boolean + def resetForceReregister(): Boolean // Read channels typically do not register for write and write does not for read @@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Will be true for ReceivingConnection, false for SendingConnection. def changeInterestForRead(): Boolean + private def disposeSasl() { + if (sparkSaslServer != null) { + sparkSaslServer.dispose(); + } + + if (sparkSaslClient != null) { + sparkSaslClient.dispose() + } + } + // On receiving a write event, should we change the interest for this channel or not ? // Will be false for ReceivingConnection, true for SendingConnection. // Actually, for now, should not get triggered for ReceivingConnection @@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, k.cancel() } channel.close() + disposeSasl() callOnCloseCallback() } @@ -168,8 +197,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId) + extends Connection(SocketChannel.open, selector_, remoteId_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslClient != null) sparkSaslClient.isComplete() else false + } private class Outbox { val messages = new Queue[Message]() @@ -226,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false + val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ @@ -316,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // If we have 'seen' pending messages, then reset flag - since we handle that as // normal registering of event (below) if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() + currentBuffers ++= buffers } case None => { @@ -384,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { +private[spark] class ReceivingConnection( + channel_ : SocketChannel, + selector_ : Selector, + id_ : ConnectionId) + extends Connection(channel_, selector_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslServer != null) sparkSaslServer.isComplete() else false + } class Inbox() { val messages = new HashMap[Int, BufferMessage]() @@ -396,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis + newMessage.isSecurityNeg = header.securityNeg == 1 logDebug( "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) @@ -441,7 
+484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
   val inbox = new Inbox()
   val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
-  var onReceiveCallback: (Connection , Message) => Unit = null
+  var onReceiveCallback: (Connection, Message) => Unit = null
   var currentChunk: MessageChunk = null
 
   channel.register(selector, SelectionKey.OP_READ)
@@ -516,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
         }
       }
     } catch {
-      case e: Exception =>  {
+      case e: Exception => {
         logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
         callOnExceptionCallback(e)
         close()
diff --git a/project/project/SparkPluginBuild.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
similarity index 55%
rename from project/project/SparkPluginBuild.scala
rename to core/src/main/scala/org/apache/spark/network/ConnectionId.scala
index a88a5e14539ec..ffaab677d411a 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -15,12 +15,20 @@
  * limitations under the License.
  */
 
-import sbt._
+package org.apache.spark.network
 
-object SparkPluginDef extends Build {
-  lazy val root = Project("plugins", file(".")) dependsOn(junitXmlListener)
-  /* This is not published in a Maven repository, so we get it from GitHub directly */
-  lazy val junitXmlListener = uri(
-    "https://github.com/chenkelmann/junit_xml_listener.git#3f8029fbfda54dc7a68b1afd2f885935e1090016"
-  )
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+  override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[spark] object ConnectionId {
+
+  def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+    val res = connectionIdString.split("_").map(_.trim())
+    if (res.size != 3) {
+      throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+        " to a ConnectionId Object")
+    }
+    new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+  }
 }
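Reviewer note: the string form produced by toString is the wire format that travels inside security messages and that createConnectionIdFromString parses back. A tiny round-trip sketch (host and port are arbitrary; both classes are private[spark]):

```scala
// Reviewer sketch only: round-trip through the "host_port_uniqId" format.
import org.apache.spark.network.{ConnectionId, ConnectionManagerId}

val connId = ConnectionId(new ConnectionManagerId("example-host", 7077), 42)
assert(connId.toString == "example-host_7077_42")
assert(ConnectionId.createConnectionIdFromString(connId.toString).toString == connId.toString)
```

One observation: a host name containing an underscore would defeat the three-way split in createConnectionIdFromString, though that should be rare in practice.

diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index a7f20f8c51a5a..a75130cba2a2e 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -21,6 +21,9 @@ import java.net._
 import java.nio._
 import java.nio.channels._
 import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
 import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
 
 import scala.collection.mutable.ArrayBuffer
@@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.SynchronizedMap
 import scala.collection.mutable.SynchronizedQueue
+
 import scala.concurrent.{Await, ExecutionContext, Future, Promise}
 import scala.concurrent.duration._
 
 import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
 
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+    securityManager: SecurityManager) extends Logging {
 
   class MessageStatus(
       val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int,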
conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() + // default to 30 second timeout waiting for authentication + private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), conf.getInt("spark.core.connection.handler.threads.max", 60), @@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() + // used to track the SendingConnections waiting to do SASL negotiation + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] + with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] @@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private val authEnabled = securityManager.isAuthenticationEnabled() + serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) @@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) + // used in combination with the ConnectionManagerId to create unique Connection ids + // to be able to track asynchronous messages + private val idCount: AtomicInteger = new AtomicInteger(1) + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } @@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi // accept them all in a tight loop. 
non blocking accept with no processing, should be fine while (newChannel != null) { try { - val newConnection = new ReceivingConnection(newChannel, selector) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi logInfo("Removing SendingConnection to " + sendingConnectionManagerId) connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId messageStatuses.synchronized { messageStatuses @@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val creationTime = System.currentTimeMillis def run() { logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) + handleMessage(connectionManagerId, message, connection) logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") } } @@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi /*handleMessage(connection, message)*/ } - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + private def handleClientAuthentication( + waitingConn: SendingConnection, + securityMsg: SecurityMessage, + connectionId : ConnectionId) { + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll(); + } + return + } else { + var replyToken : Array[Byte] = null + try { + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll() + } + return + } + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId.toString()) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) + } catch { + case e: Exception => { + logError("Error handling sasl client authentication", e) + waitingConn.close() + throw new Exception("Error evaluating sasl response: " + e) + } + } + } + } + + private def handleServerAuthentication( + connection: Connection, + securityMsg: SecurityMessage, + connectionId: ConnectionId) { + if (!connection.isSaslComplete()) { + logDebug("saslContext not established") + var replyToken : Array[Byte] = null + try { + connection.synchronized { + if (connection.sparkSaslServer == null) { + logDebug("Creating sasl Server") + connection.sparkSaslServer = new SparkSaslServer(securityManager) + } + } + replyToken = connection.sparkSaslServer.response(securityMsg.getToken) + if (connection.isSaslComplete()) { + logDebug("Server sasl completed: " + connection.connectionId) + } else { + logDebug("Server sasl not completed: " + connection.connectionId) + } + if (replyToken != null) { + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + 
securityMsg.getConnectionId) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) + } + } catch { + case e: Exception => { + logError("Error in server auth negotiation: " + e) + // It would probably be better to send an error message telling the other side that auth + // failed, but for now just close the connection + connection.close() + } + } + } else { + logDebug("SASL already completed for connection id: " + connection.connectionId) + } + } + + + private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { + if (bufferMessage.isSecurityNeg) { + logDebug("This is a security negotiation message") + + // parse as SecurityMessage + val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) + val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) + + connectionsAwaitingSasl.get(connectionId) match { + case Some(waitingConn) => { + // Client - this must be in response to a send we initiated + logDebug("Client handleAuth for id: " + waitingConn.connectionId) + handleClientAuthentication(waitingConn, securityMsg, connectionId) + } + case None => { + // Server - someone sent us something and we haven't authenticated yet + logDebug("Server handleAuth for id: " + connectionId) + handleServerAuthentication(conn, securityMsg, connectionId) + } + } + return true + } else { + if (!conn.isSaslComplete()) { + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore it. + logError("Received a message that is not a security negotiation message on a " + + "connection that is not yet authenticated, ignoring it!") + return true + } + } + return false + } + + private def handleMessage( + connectionManagerId: ConnectionManagerId, + message: Message, + connection: Connection) { logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { + if (authEnabled) { + val res = handleAuthentication(connection, bufferMessage) + if (res) { + // the message was security negotiation, so skip the rest + logDebug("After handleAuth result was true, returning") + return + } + } if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { @@ -541,17 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } } + private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { + // see if we need to do SASL negotiation before writing + // this should only happen for the first negotiation, where we act as the client
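+ // Flow recap: as the client we create a SparkSaslClient, send its first token via
+ // sendSecurityMessage, and register in connectionsAwaitingSasl; the server's replies
+ // are then consumed by handleClientAuthentication, which notifies getAuthenticated().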
+ if (!conn.isSaslComplete()) { + conn.synchronized { + if (conn.sparkSaslClient == null) { + conn.sparkSaslClient = new SparkSaslClient(securityManager) + var firstResponse: Array[Byte] = null + try { + firstResponse = conn.sparkSaslClient.firstToken() + var securityMsg = SecurityMessage.fromResponse(firstResponse, + conn.connectionId.toString()) + var message = securityMsg.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + connectionsAwaitingSasl += ((conn.connectionId, conn)) + sendSecurityMessage(connManagerId, message) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + } catch { + case e: Exception => { + logError("Error getting first response from the SaslClient.", e) + conn.close() + throw new Exception("Error getting first response from the SaslClient") + } + } + } + } + } else { + logDebug("SASL already established") + } + } + + // allow us to add messages to the inbox for doing SASL negotiation + private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, + newConnectionId) + logInfo("creating new sending connection for security! " + newConnectionId) + registerRequests.enqueue(newConnection) + + newConnection + } + // The lookupKey logic was removed as part of the merge; we did not find it useful in our + // test environment. If it is re-added, it should be used consistently everywhere. + message.senderAddress = id.toSocketAddress() + logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") + val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) + + // send security messages even though the outgoing connection has not been authenticated yet + connection.send(message) + + wakeupSelector() + } + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, + newConnectionId) + logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection } val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) + if (authEnabled) { + checkSendAuthFirst(connectionManagerId, connection) + } message.senderAddress = id.toSocketAddress() + logDebug("Before sending [" + message + "] to [" + connectionManagerId + "]" + " " + + "connectionId: " + connection.connectionId) + + if (authEnabled) { + // if we aren't authenticated yet, let's block the senders until authentication completes + try { + connection.getAuthenticated().synchronized { + val clock = SystemClock + val startTime = clock.getTime() + + while (!connection.isSaslComplete()) { + logDebug("getAuthenticated wait connectionId: " + connection.connectionId) + // use a timeout in case the remote side never responds + connection.getAuthenticated().wait(500) + if (((clock.getTime() - startTime) >= (authTimeout * 1000)) + && (!connection.isSaslComplete())) { + // took too long to authenticate the connection, something probably went wrong + throw new Exception("Took too long for authentication to " + connectionManagerId + + ", waited " + authTimeout + " seconds, failing.") + } + } + } + } catch { + case e: Exception => logError("Exception while waiting for authentication.", e) + + // need to tell the sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(message.id) + s match { + case Some(msgStatus) => { + messageStatuses -= message.id + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.synchronized { + msgStatus.attempted = true + msgStatus.acked = false + msgStatus.markDone() + } + } + case None => { + logError("no messageStatus for failed message id: " + message.id) + } + } + } + } + } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) @@ -603,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 20fe67661844f..7caccfdbb44f9 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var started = false var startTime = -1L var finishTime = -1L + var isSecurityNeg = false def size: Int diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index 9bcbc6141a502..ead663ede7a1c 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { // No need to change this, at 'use' time, we do a reverse lookup of the hostname. putInt(totalSize). putInt(chunkSize). putInt(other). + putInt(securityNeg). putInt(ip.size). put(ip). putInt(port).
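The HEADER_SIZE change just below (40 to 44) pays for the new 4-byte securityNeg int. A minimal sketch of the invariant, assuming a simplified field layout rather than Spark's exact one: both ends must write and read the fields in the same order, and the fixed size constant has to change in lockstep on both sides.

```scala
import java.nio.ByteBuffer

object HeaderSketch {
  val HEADER_SIZE = 44 // was 40 before the 4-byte securityNeg field was added

  def write(typ: Long, id: Int, totalSize: Int, chunkSize: Int, other: Int,
            securityNeg: Int, ip: Array[Byte], port: Int): ByteBuffer = {
    val buf = ByteBuffer.allocate(HEADER_SIZE)
    buf.putLong(typ).putInt(id).putInt(totalSize).putInt(chunkSize)
      .putInt(other).putInt(securityNeg).putInt(ip.length).put(ip).putInt(port)
    buf.position(HEADER_SIZE) // pad out to the fixed header size
    buf.flip()
    buf
  }

  def readSecurityNeg(buf: ByteBuffer): Int = {
    require(buf.remaining == HEADER_SIZE, "both ends must agree on HEADER_SIZE")
    buf.getLong() // typ
    buf.getInt(); buf.getInt(); buf.getInt(); buf.getInt() // id, totalSize, chunkSize, other
    buf.getInt() // securityNeg, read in exactly the order it was written
  }
}
```

If one side still uses the old constant, the buffer.remaining check in create() (next hunk) rejects the chunk, which is the failure mode to expect when mixing builds across this change.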
@@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader( } override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 + val HEADER_SIZE = 44 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 9976255c7e251..3c09a713c6fe0 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -18,12 +18,12 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala new file mode 100644 index 0000000000000..0d9f743b3624b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.StringBuilder + +import org.apache.spark._ +import org.apache.spark.network._ + +/** + * SecurityMessage is a class that contains the connectionId and SASL token + * used in SASL negotiation. SecurityMessage has routines for converting + * it to and from a BufferMessage so that it can be sent by the ConnectionManager + * and easily consumed by users when received. + * The API was modeled after BlockMessage. + * + * The connectionId is the connectionId of the client side. Since + * message passing is asynchronous and it's possible for the server side (receiving) + * to get multiple different types of messages on the same connection, the connectionId + * is used to know which connection the security message is intended for. + * + * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side + * is acting as a client and connecting to node_1. SASL negotiation has to occur + * between node_0 and node_1 before node_1 trusts node_0, so node_0 sends a security message. + * node_1 receives the message from node_0, but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0, so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if it's just node_1 trying to connect to it to send data. This + * is where the connectionId field is used. node_0 can look up the connectionId to see if + * it is in response to it being a client or if it's in response to someone sending other data. + * + * The format of a SecurityMessage as it's sent is: + * - Length of the ConnectionId + * - ConnectionId + * - Length of the token + * - Token + */ +private[spark] class SecurityMessage() extends Logging { + + private var connectionId: String = null + private var token: Array[Byte] = null + + def set(byteArr: Array[Byte], newConnectionId: String) { + if (byteArr == null) { + token = new Array[Byte](0) + } else { + token = byteArr + } + connectionId = newConnectionId + } + + /** + * Read the given buffer and set the members of this class. + */ + def set(buffer: ByteBuffer) { + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + connectionId = idBuilder.toString() + + val tokenLength = buffer.getInt() + token = new Array[Byte](tokenLength) + if (tokenLength > 0) { + buffer.get(token, 0, tokenLength) + } + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getConnectionId: String = { + return connectionId + } + + def getToken: Array[Byte] = { + return token + } + + /** + * Create a BufferMessage that can be sent by the ConnectionManager containing + * the security information from this class. + * @return BufferMessage + */ + def toBufferMessage: BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + + // 4 bytes for the length of the connectionId + // connectionId is made of chars, so multiply the length by 2 to get the number of bytes + // 4 bytes for the length of token + // token is a byte buffer so just take the length + var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) + buffer.putInt(connectionId.length()) + connectionId.foreach((x: Char) => buffer.putChar(x)) + buffer.putInt(token.length) + + if (token.length > 0) { + buffer.put(token) + } + buffer.flip() + buffers += buffer + + var message = Message.createBufferMessage(buffers) + logDebug("message total size is: " + message.size) + message.isSecurityNeg = true + return message + } + + override def toString: String = { + "SecurityMessage [connId = " + connectionId + ", Token = " + token + "]" + } +} + +private[spark] object SecurityMessage { + + /** + * Convert the given BufferMessage to a SecurityMessage by parsing the contents + * of the BufferMessage and populating the SecurityMessage fields. + * @param bufferMessage is a BufferMessage that was received + * @return new SecurityMessage + */ + def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(bufferMessage) + newSecurityMessage + } + + /** + * Create a SecurityMessage to send from a given saslResponse. + * @param response is the response to a challenge from the SaslClient or SaslServer + * @param connectionId the client connectionId we are negotiating authentication for + * @return a new SecurityMessage + */ + def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(response, connectionId) + newSecurityMessage + } +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 646f8425d9551..aac2c24a46faa 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,8 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { def main(args: Array[String]) { @@ -32,8 +31,8 @@ private[spark] object SenderTest { val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a374fc4a871b0..100ddb360732a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -18,8 +18,10 @@ package org.apache.spark.rdd import java.io.EOFException +import scala.collection.immutable.Map import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapred.InputFormat import
org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf @@ -43,6 +45,23 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp override def hashCode(): Int = 41 * (41 + rddId) + idx override val index: Int = idx + + /** + * Get any environment variables that should be added to the user's environment when running pipes + * @return a Map with the environment variables and corresponding values; it may be empty + */ + def getPipeEnvVars(): Map[String, String] = { + val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) { + val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit] + // map_input_file is deprecated in favor of mapreduce_map_input_file, but set both + // since it's not removed yet + Map("map_input_file" -> is.getPath().toString(), + "mapreduce_map_input_file" -> is.getPath().toString()) + } else { + Map() + } + envVars + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index abd4414e81f5c..4250a9d02f764 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkEnv, TaskContext} + /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. @@ -59,6 +60,13 @@ class PipedRDD[T: ClassTag]( val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } + // for compatibility with Hadoop, which sets these env variables + // so that user code can access the input filename + if (split.isInstanceOf[HadoopPartition]) { + val hadoopSplit = split.asInstanceOf[HadoopPartition] + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + } + val proc = pb.start() val env = SparkEnv.get diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 33c1705ad7c58..bfa647f7f0516 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -23,9 +23,28 @@ import java.nio.ByteBuffer import org.apache.spark.SparkConf import org.apache.spark.util.ByteBufferInputStream -private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { +private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf) + extends SerializationStream { val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } + var counter = 0 + val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) + + /** + * Calling reset to avoid a memory leak: + * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api + * But only call it every counterReset-th write (10,000 by default) to avoid bloated + * serialization streams: when the stream resets, object class descriptions have to be + * re-written. + */ + def writeObject[T](t: T): SerializationStream = { + objOut.writeObject(t) + if (counterReset > 0 && counter >= counterReset) { + objOut.reset() + counter = 0 + } else { + counter += 1 + } + this + } def flush() { objOut.flush() } def close() { objOut.close() } } @@ -41,7 +60,7 @@ extends DeserializationStream { def close() { objIn.close() } }
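As a side note on the writeObject/reset change above: a plain JDK ObjectOutputStream keeps a back-reference table of every object written, which both retains memory and makes repeated writes of a mutated object serialize stale state. A small self-contained illustration (JDK-only, nothing Spark-specific):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object ResetDemo extends App {
  val bytes = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bytes)

  val arr = Array(1, 2, 3)
  oos.writeObject(arr)
  arr(0) = 99
  oos.writeObject(arr) // only a back-reference is written: a reader sees Array(1, 2, 3) again
  oos.reset()          // drops the back-reference table so old objects can be GC'd
  oos.writeObject(arr) // now the mutated Array(99, 2, 3) is fully re-serialized
  oos.close()
}
```

spark.serializer.objectStreamReset trades the memory held by that table against the cost of re-writing class descriptors after each reset.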
-private[spark] class JavaSerializerInstance extends SerializerInstance { +private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance { def serialize[T](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -63,7 +82,7 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { } def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s) + new JavaSerializationStream(s, conf) } def deserializeStream(s: InputStream): DeserializationStream = { @@ -79,5 +98,5 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { * A Spark serializer that uses Java's built-in serialization. */ class JavaSerializer(conf: SparkConf) extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance + def newInstance(): SerializerInstance = new JavaSerializerInstance(conf) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a734ddc1ef702..1bf3f4db32ea7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,19 +29,26 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ +sealed trait Values + +case class ByteBufferValues(buffer: ByteBuffer) extends Values +case class IteratorValues(iterator: Iterator[Any]) extends Values +case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values + private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val defaultSerializer: Serializer, maxMemory: Long, - val conf: SparkConf) + val conf: SparkConf, + securityManager: SecurityManager) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -60,7 +67,7 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf) + val connectionManager = new ConnectionManager(0, conf, securityManager) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( @@ -116,8 +123,9 @@ private[spark] class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. 
*/ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, conf: SparkConf) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) + serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf, + securityManager) } /** @@ -455,9 +463,7 @@ private[spark] class BlockManager( def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) : Long = { - val elements = new ArrayBuffer[Any] - elements ++= values - put(blockId, elements, level, tellMaster) + doPut(blockId, IteratorValues(values), level, tellMaster) } /** @@ -479,7 +485,7 @@ private[spark] class BlockManager( def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, tellMaster: Boolean = true) : Long = { require(values != null, "Values is null") - doPut(blockId, Left(values), level, tellMaster) + doPut(blockId, ArrayBufferValues(values), level, tellMaster) } /** @@ -488,10 +494,11 @@ private[spark] class BlockManager( def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { require(bytes != null, "Bytes is null") - doPut(blockId, Right(bytes), level, tellMaster) + doPut(blockId, ByteBufferValues(bytes), level, tellMaster) } - private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer], + private def doPut(blockId: BlockId, + data: Values, level: StorageLevel, tellMaster: Boolean = true): Long = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") @@ -534,8 +541,9 @@ private[spark] class BlockManager( // If we're storing bytes, then initiate the replication before storing them locally. // This is faster as data is already serialized and ready to send. - val replicationFuture = if (data.isRight && level.replication > 1) { - val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper + val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) { + // Duplicate doesn't copy the bytes, just creates a wrapper + val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate() Future { replicate(blockId, bufferView, level) } @@ -549,34 +557,43 @@ private[spark] class BlockManager( var marked = false try { - data match { - case Left(values) => { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will - // drop it to disk later if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will + // drop it to disk later if the memory store can't hold it. 
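+          // dispatch on the Values wrapper introduced above: iterator and array inputs go
+          // through putValues, while already-serialized bytes go straight to putBytes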
+ val res = data match { + case IteratorValues(iterator) => + memoryStore.putValues(blockId, iterator, level, true) + case ArrayBufferValues(array) => + memoryStore.putValues(blockId, array, level, true) + case ByteBufferValues(bytes) => { + bytes.rewind(); + memoryStore.putBytes(blockId, bytes, level) + } + } + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + + val res = data match { + case IteratorValues(iterator) => + diskStore.putValues(blockId, iterator, level, askForBytes) + case ArrayBufferValues(array) => + diskStore.putValues(blockId, array, level, askForBytes) + case ByteBufferValues(bytes) => { + bytes.rewind(); + diskStore.putBytes(blockId, bytes, level) } } - case Right(bytes) => { - bytes.rewind() - // Store it only in memory at first, even if useDisk is also set to true - (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level) - size = bytes.limit + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => } } @@ -605,8 +622,8 @@ private[spark] class BlockManager( // values and need to serialize and replicate them now: if (level.replication > 1) { data match { - case Right(bytes) => Await.ready(replicationFuture, Duration.Inf) - case Left(values) => { + case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf) + case _ => { val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index b047644b88f48..9a9be047c7245 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. 
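The sealed Values hierarchy above replaces the old Either[ArrayBuffer[Any], ByteBuffer] argument of doPut, so each store can match on three input shapes instead of two. A self-contained sketch of the same pattern (the describe function is illustrative, not part of the patch):

```scala
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

object ValuesSketch {
  sealed trait Values
  case class ByteBufferValues(buffer: ByteBuffer) extends Values
  case class IteratorValues(iterator: Iterator[Any]) extends Values
  case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values

  def describe(data: Values): String = data match {
    case ByteBufferValues(bytes)  => "already serialized, " + bytes.remaining + " bytes"
    case IteratorValues(_)        => "lazy values, size unknown until consumed"
    case ArrayBufferValues(array) => "materialized values, " + array.size + " elements"
  }
}
```

Because the trait is sealed, the compiler warns if a store forgets one of the three cases, which the two-sided Either could not express.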
@@ -37,6 +37,9 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ + def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel, + returnValues: Boolean) : PutResult + def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) : PutResult diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index d1f07ddb24bb2..36ee4bcc41c66 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -37,7 +37,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage diskManager.getBlockLocation(blockId).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() @@ -52,6 +52,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) + return PutResult(bytes.limit(), Right(bytes.duplicate())) } override def putValues( @@ -59,13 +60,22 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) + : PutResult = { + return putValues(blockId, values.toIterator, level, returnValues) + } + + override def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) : PutResult = { logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) - blockManager.dataSerializeStream(blockId, outputStream, values.iterator) + blockManager.dataSerializeStream(blockId, outputStream, values) val length = file.length val timeTaken = System.currentTimeMillis - startTime diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 18141756518c5..38836d44b04e8 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -49,7 +49,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { // Work on a duplicate - since the original input might be used elsewhere. 
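+    // (duplicate() shares the underlying bytes but gives this method its own position/limit
+    // cursor, so the rewind below cannot disturb the caller's view of the buffer)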
val bytes = _bytes.duplicate() bytes.rewind() @@ -59,8 +59,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) elements ++= values val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) tryToPut(blockId, elements, sizeEstimate, true) + PutResult(sizeEstimate, Left(values.toIterator)) } else { tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -69,14 +71,33 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) - : PutResult = { - + : PutResult = { if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) tryToPut(blockId, values, sizeEstimate, true) - PutResult(sizeEstimate, Left(values.iterator)) + PutResult(sizeEstimate, Left(values.toIterator)) + } else { + val bytes = blockManager.dataSerialize(blockId, values.toIterator) + tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes.duplicate())) + } + } + + override def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + if (level.deserialized) { + val valueEntries = new ArrayBuffer[Any]() + valueEntries ++= values + val sizeEstimate = SizeEstimator.estimate(valueEntries.asInstanceOf[AnyRef]) + tryToPut(blockId, valueEntries, sizeEstimate, true) + PutResult(sizeEstimate, Left(valueEntries.toIterator)) } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) + val bytes = blockManager.dataSerialize(blockId, values) tryToPut(blockId, bytes, bytes.limit, false) PutResult(bytes.limit(), Right(bytes.duplicate())) } @@ -215,13 +236,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd.isDefined && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } } @@ -243,6 +261,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } return true } else { + logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " + + "from the same RDD") return false } } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 1d81d006c0b29..36f2a0fd02724 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -24,6 +24,7 @@ import util.Random import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.{SecurityManager, SparkConf} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -98,7 +99,8 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf) + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, + new 
SecurityManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 1b78c52ff6077..7c35cd165ad7c 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,7 +18,8 @@ package org.apache.spark.ui import java.net.InetSocketAddress -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import java.net.URL +import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} @@ -26,11 +27,14 @@ import scala.xml.Node import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} -import org.eclipse.jetty.server.{Handler, Request, Server} -import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler} + +import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.handler.HandlerList +import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.apache.spark.Logging +import org.apache.spark.{Logging, SecurityManager, SparkConf} + /** Utilities for launching a web server using Jetty's HTTP Server class */ private[spark] object JettyUtils extends Logging { @@ -39,57 +43,104 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) + class ServletParams[T <% AnyRef](val responder: Responder[T], + val contentType: String, + val extractFn: T => String = (in: Any) => in.toString) {} + + // Conversions from various types of Responder's to appropriate servlet parameters + implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] = + new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = - createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) + implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] = + new ServletParams(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - implicit def textResponderToHandler(responder: Responder[String]): Handler = - createHandler(responder, "text/plain") + implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = + new ServletParams(responder, "text/plain") - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createServlet[T <% AnyRef](servletParams: ServletParams[T], + securityMgr: SecurityManager): HttpServlet = { + new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - response.setStatus(HttpServletResponse.SC_OK) - 
baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) + if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter().println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page."); + } } } } + def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler + } + /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createRedirectHandler(newPath: String, path: String): ServletContextHandler = { + val servlet = new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) + // make sure we don't end up with // in the middle + val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI + response.sendRedirect(newUri.toString) } } + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler } /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler + def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val staticHandler = new DefaultServlet + val holder = new ServletHolder(staticHandler) Option(getClass.getClassLoader.getResource(resourceBase)) match { case Some(res) => - staticHandler.setResourceBase(res.toString) + holder.setInitParameter("resourceBase", res.toString) case None => throw new Exception("Could not find resource path for Web UI: " + resourceBase) } - staticHandler + contextHandler.addServlet(holder, path) + contextHandler + } + + private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { + val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) + filters.foreach { + case filter : String => + if (!filter.isEmpty) { + logInfo("Adding filter: " + filter) + val holder : FilterHolder = new FilterHolder() + holder.setClassName(filter) + // get any parameters for each filter + val paramName = "spark." 
+ filter + ".params" + val params = conf.get(paramName, "").split(',').map(_.trim()).toSet + params.foreach { + case param : String => + if (!param.isEmpty) { + val parts = param.split("=") + if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) + } + } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, + DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) + handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } + } + } } /** @@ -99,17 +150,12 @@ private[spark] object JettyUtils extends Logging { * If the desired port number is contented, continues incrementing ports until a free port is * found. Returns the chosen port and the jetty Server object. */ - def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) - = { - - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } + def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler], + conf: SparkConf): (Server, Int) = { + addFilters(handlers, conf) val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) + handlerList.setHandlers(handlers.toArray) @tailrec def connect(currentPort: Int): (Server, Int) = { @@ -119,7 +165,9 @@ private[spark] object JettyUtils extends Logging { server.setThreadPool(pool) server.setHandler(handlerList) - Try { server.start() } match { + Try { + server.start() + } match { case s: Success[_] => (server, server.getConnectors.head.getLocalPort) case f: Failure[_] => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index af6b65860e006..ca82c3da2fc24 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,7 +17,10 @@ package org.apache.spark.ui -import org.eclipse.jetty.server.{Handler, Server} +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.ui.JettyUtils._ @@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { var boundPort: Option[Int] = None var server: Option[Server] = None - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) + val handlers = Seq[ServletContextHandler] ( + createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static/*"), + createRedirectHandler("/stages", "/") ) val storage = new BlockManagerUI(sc) val jobs = new JobProgressUI(sc) @@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { /** Bind the HTTP server which backs this web interface */ def bind() { try { - val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers) + val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf) logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) server = Some(srv) boundPort = Some(usedPort) @@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { private[spark] object SparkUI { val DEFAULT_PORT = "4040" - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = 
"org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index 9e7cdc88162e8..14333476c0e31 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.util.Properties import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils private[spark] class EnvironmentUI(sc: SparkContext) { - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/environment", + createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager)) ) def envDetails(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 1f3b7a4c231b6..4235cfeff9fa2 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{ExceptionFailure, Logging, SparkContext} import org.apache.spark.executor.TaskMetrics @@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { sc.addSparkListener(listener) } - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/executors", createServlet((request: HttpServletRequest) => render + (request), sc.env.securityManager)) ) def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index 557bce6b66353..2d95d47e154cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.Seq import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def formatDuration(ms: Long) = Utils.msDurationToString(ms) - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/stages/stage", + createServlet((request: HttpServletRequest) => stagePage.render(request), + sc.env.securityManager)), + createServletHandler("/stages/pool", + createServlet((request: HttpServletRequest) => poolPage.render(request), + sc.env.securityManager)), + 
createServletHandler("/stages", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index dc18eab74e0da..cb2083eb019bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext} import org.apache.spark.ui.JettyUtils._ @@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { val indexPage = new IndexPage(this) val rddPage = new RDDPage(this) - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/storage/rdd", + createServlet((request: HttpServletRequest) => rddPage.render(request), + sc.env.securityManager)), + createServletHandler("/storage", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index f26ed47e58046..a6c9a9aaba8eb 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. */ -private[spark] object AkkaUtils { +private[spark] object AkkaUtils extends Logging { /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the @@ -42,7 +42,7 @@ private[spark] object AkkaUtils { * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. 
*/ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, - conf: SparkConf): (ActorSystem, Int) = { + conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) @@ -65,6 +65,15 @@ private[spark] object AkkaUtils { conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val secretKey = securityManager.getSecretKey() + val isAuthOn = securityManager.isAuthenticationEnabled() + if (isAuthOn && secretKey == null) { + throw new Exception("Secret key is null with authentication on") + } + val requireCookie = if (isAuthOn) "on" else "off" + val secureCookie = if (isAuthOn) secretKey else "" + logDebug("In createActorSystem, requireCookie is: " + requireCookie) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( s""" @@ -72,6 +81,8 @@ private[spark] object AkkaUtils { |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] |akka.stdout-loglevel = "ERROR" |akka.jvm-exit-on-fatal-error = off + |akka.remote.require-cookie = "$requireCookie" + |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 681d0a30cb3f8..a8d20ee332355 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.objectweb.asm.Opcodes._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.Logging diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8e69f1d3351b5..ac376fc403ada 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL} +import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection} import java.nio.ByteBuffer import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -33,10 +33,11 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil + /** * Various utility methods used by Spark. 
*/ @@ -232,6 +233,22 @@ private[spark] object Utils extends Logging { } } + /** + * Construct a URI containing information used for authentication. + * This also sets the default authenticator to properly negotiate the + * user/password based on the URI. + * + * Note this relies on Authenticator.setDefault being set properly to decode + * the user name and password. This is currently set in the SecurityManager. + */ + def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = { + val userCred = securityMgr.getSecretKey() + if (userCred == null) throw new Exception("Secret key is null with authentication on") + val userInfo = securityMgr.getHttpUser() + ":" + userCred + new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), + uri.getQuery(), uri.getFragment()) + } + /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -239,7 +256,7 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - def fetchFile(url: String, targetDir: File, conf: SparkConf) { + def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) { val filename = url.split("/").last val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) @@ -249,7 +266,23 @@ private[spark] object Utils extends Logging { uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) + + var uc: URLConnection = null + if (securityMgr.isAuthenticationEnabled()) { + logDebug("fetchFile with security enabled") + val newuri = constructURIForAuthentication(uri, securityMgr) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("fetchFile not using security") + uc = new URL(url).openConnection() + } + + val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 + uc.setConnectTimeout(timeout) + uc.setReadTimeout(timeout) + uc.connect() + val in = uc.getInputStream(); val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { @@ -503,8 +536,6 @@ private[spark] object Utils extends Logging { /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM - * environment variable. 
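A sketch of how the constructURIForAuthentication helper above is meant to be used on the fetch path; the file-server URL is hypothetical, and since Utils is private[spark] this is illustrative only:

    import java.net.URI
    import org.apache.spark.{SecurityManager, SparkConf}

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "secret")
      .set("spark.files.fetchTimeout", "30") // seconds; multiplied by 1000 above
    val securityMgr = new SecurityManager(conf)

    // Embed "<httpUser>:<secret>" as the userInfo part of the URI so the default
    // Authenticator installed by SecurityManager can answer the HTTP challenge.
    val uri = new URI("http://fileserver:8080/jars/app.jar") // hypothetical file server
    val secured = Utils.constructURIForAuthentication(uri, securityMgr)
    // secured now carries the user info: http://<httpUser>:<secret>@fileserver:8080/jars/app.jar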
*/ def memoryStringToMb(str: String): Int = { val lower = str.toLowerCase diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c7d0e2d577726..40e853c39ca99 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -110,6 +110,37 @@ public void sparkContextUnion() { Assert.assertEquals(4, pUnion.count()); } + @SuppressWarnings("unchecked") + @Test + public void intersection() { + List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5); + List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8); + JavaRDD<Integer> s1 = sc.parallelize(ints1); + JavaRDD<Integer> s2 = sc.parallelize(ints2); + + JavaRDD<Integer> intersections = s1.intersection(s2); + Assert.assertEquals(3, intersections.count()); + + ArrayList<Integer> list = new ArrayList<Integer>(); + JavaRDD<Integer> empty = sc.parallelize(list); + JavaRDD<Integer> emptyIntersection = empty.intersection(s2); + Assert.assertEquals(0, emptyIntersection.count()); + + List<Double> doubles = Arrays.asList(1.0, 2.0); + JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD dIntersection = d1.intersection(d2); + Assert.assertEquals(2, dIntersection.count()); + + List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); + pairs.add(new Tuple2<Integer, Integer>(1, 2)); + pairs.add(new Tuple2<Integer, Integer>(3, 4)); + JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2); + Assert.assertEquals(2, pIntersection.count()); + } + @Test public void sortByKey() { List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala new file mode 100644 index 0000000000000..cd054c1f684ab --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils +import scala.concurrent.Await + +/** + * Test the AkkaUtils with various security settings. 
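The Java intersection test above exercises the same RDD.intersection operation available from Scala; a one-liner sketch (an active SparkContext sc assumed):

    val a = sc.parallelize(Seq(1, 10, 2, 3, 4, 5))
    val b = sc.parallelize(Seq(1, 6, 2, 3, 7, 8))
    // Result contains each common element once; ordering is not guaranteed.
    a.intersection(b).collect() // => Array(1, 2, 3), up to ordering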
+ */ +class AkkaUtilsSuite extends FunSuite with LocalSparkContext { + + test("remote fetch security bad password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "bad") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(badconf); + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + 
slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security is off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security pass") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val goodconf = new SparkConf + goodconf.set("spark.authenticate", "true") + goodconf.set("spark.authenticate.secret", "good") + val securityManagerGood = new SecurityManager(goodconf); + + assert(securityManagerGood.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = goodconf, securityManager = securityManagerGood) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security is on and passwords match + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off client") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + val (slaveSystem, _) = 
AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..96ba3929c1685 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { + override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala new file mode 100644 index 0000000000000..80f7ec00c74b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import java.nio._ + +import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId} +import scala.concurrent.Await +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + + +/** + * Test the ConnectionManager with various security settings. 
+ */ +class ConnectionManagerSuite extends FunSuite { + + test("security default off") { + val conf = new SparkConf + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var receivedMessage = false + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + receivedMessage = true + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + + assert(receivedMessage == true) + + manager.stop() + } + + test("security on same password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + val managerServer = new ConnectionManager(0, conf, securityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val count = 10 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + }) + + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "good") + 
val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 1).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + assert(false) + } catch { + case e: TimeoutException => { + // we should timeout here since the client can't do the negotiation + assert(true) + } + } + }) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + manager.stop() + managerServer.stop() + } + + test("security auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 10).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + if (!g.isDefined) assert(false) else assert(true) + } catch { + case e: Exception => { + assert(false) + } + } + }) + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + + +} + diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e0e8011278649..9cbdfc54a3dc8 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 9be67b3c95abd..aee9ab9091dac 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + override def beforeEach() { + super.beforeEach() + 
resetSparkContext() + System.setProperty("spark.authenticate", "false") + } + override def beforeAll() { super.beforeAll() val tmpDir = new File(Files.createTempDir(), "test") @@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) + System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set((1,200), (2,300), (3,500))) } + test("Distributing files locally security On") { + val sparkConf = new SparkConf(false) + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + sc = new SparkContext("local[4]", "test", sparkConf) + + sc.addFile(tmpFile.toString) + assert(sc.env.securityManager.isAuthenticationEnabled() === true) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6c1e325f6f348..8efa072a97911 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -98,14 +98,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala index 3a0385a1b0bd9..0bac78d8a6bdf 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -19,74 +19,152 @@ package org.apache.spark import org.scalatest.FunSuite + +import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition} +import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit} +import org.apache.hadoop.fs.Path + +import scala.collection.Map +import scala.sys.process._ +import scala.util.Try +import 
org.apache.hadoop.io.{Text, LongWritable} + class PipedRDDSuite extends FunSuite with SharedSparkContext { test("basic pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat")) + val piped = nums.pipe(Seq("cat")) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + } else { + assert(true) + } } test("advanced pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Int, f: String=> Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.map(f(_)); f("\u0001") + }, + (i: Int, f: String => Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str => str.split("\t")(0)). 
+ pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.map(f(_)); f("\u0001") + }, + (i: Tuple2[String, Seq[String]], f: String => Unit) => { + for (e <- i._2) { + f(e + "_") + } + }).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } else { + assert(true) + } } test("pipe with env variable") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") + if (testCommandAvailable("printenv")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.size === 2) + assert(c(0) === "LALALA") + assert(c(1) === "LALALA") + } else { + assert(true) + } } test("pipe with non-zero exit status") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) - intercept[SparkException] { - piped.collect() + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) + intercept[SparkException] { + piped.collect() + } + } else { + assert(true) } } + test("test pipe exports map_input_file") { + testExportInputFile("map_input_file") + } + + test("test pipe exports mapreduce_map_input_file") { + testExportInputFile("mapreduce_map_input_file") + } + + def testCommandAvailable(command: String): Boolean = { + Try(Process(command) !!).isSuccess + } + + def testExportInputFile(varName: String) { + if (testCommandAvailable("printenv")) { + val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], + classOf[Text], 2) { + override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) + + override val getDependencies = List[Dependency[_]]() + + override def compute(theSplit: Partition, context: TaskContext) = { + new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), + new Text("b")))) + } + } + val hadoopPart1 = generateFakeHadoopPartition() + val pipedRdd = new PipedRDD(nums, "printenv " + varName) + val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) + val rddIter = pipedRdd.compute(hadoopPart1, tContext) + val arr = rddIter.toArray + assert(arr(0) == "/some/path") + } else { + // printenv isn't available so just pass the test + assert(true) + } + } + + def generateFakeHadoopPartition(): HadoopPartition = { + val split = new FileSplit(new Path("/some/path"), 0, 1, + Array[String]("loc1", "loc2", "loc3", "loc4", "loc5")) + new HadoopPartition(sc.newRddId(), 1, split) + } + } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f28d5c7b133b3..3bb936790d506 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -95,6 +95,10 @@ class SparkContextSchedulerCreationSuite } } + test("yarn-cluster") { + testYarn("yarn-cluster", 
"org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + test("yarn-standalone") { testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index c1e8b295dfe3b..96a5a1231813e 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -18,21 +18,22 @@ package org.apache.spark.metrics import org.scalatest.{BeforeAndAfter, FunSuite} - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource class MetricsSystemSuite extends FunSuite with BeforeAndAfter { var filePath: String = _ var conf: SparkConf = null + var securityMgr: SecurityManager = null before { filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() conf = new SparkConf(false).set("spark.metrics.conf", filePath) + securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks @@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter { } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9f011d9c8d132..1036b9f34e9dd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var actorSystem: ActorSystem = null var master: BlockManagerMaster = null var oldArch: String = null + conf.set("spark.authenticate", "false") + val securityMgr = new SecurityManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf, + securityManager = securityMgr) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) @@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 
manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, + securityMgr) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + 
store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new 
Array[Byte](400) @@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) 
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf) + store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, + securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass @@ -657,4 +662,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a1") == None, "a1 should not be in store") } } + + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // Access rdd_1_0 to ensure it's not least recently used. + assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") + // According to the same-RDD rule, rdd_1_0 should be replaced here. + store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // rdd_1_0 should have been replaced, even though it's not least recently used. + assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store") + assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") + assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") + } } diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala new file mode 100644 index 0000000000000..bcf138b5ee6d0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext} + + +class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { + /* Tests the ability of Spark to deal with user provided iterators from flatMap + * calls that may generate more data than available memory. In any + * memory-based persistence Spark will unroll the iterator into an ArrayBuffer + * for caching; however, in the case that the user defines DISK_ONLY persistence, + * the iterator will be fed directly to the serializer and written to disk. + * + * This also tests the ObjectOutputStream reset rate. When serializing using the + * Java serialization system, the serializer caches objects to prevent writing redundant + * data; however, that stops GC of those objects. By calling 'reset' you flush that + * info from the serializer and allow old objects to be GC'd. + */ + test("Flatmap Iterator to Disk") { + val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test") + sc = new SparkContext(sconf) + val expand_size = 100 + val data = sc.parallelize((1 to 5).toSeq). + flatMap( x => Stream.range(0, expand_size)) + var persisted = data.persist(StorageLevel.DISK_ONLY) + assert(persisted.count()===500) + assert(persisted.filter(_==1).count()===5) + } + + test("Flatmap Iterator to Memory") { + val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test") + sc = new SparkContext(sconf) + val expand_size = 100 + val data = sc.parallelize((1 to 5).toSeq). + flatMap(x => Stream.range(0, expand_size)) + var persisted = data.persist(StorageLevel.MEMORY_ONLY) + assert(persisted.count()===500) + assert(persisted.filter(_==1).count()===5) + } + + test("Serializer Reset") { + val sconf = new SparkConf().setMaster("local").setAppName("serializer_reset_test") + .set("spark.serializer.objectStreamReset", "10") + sc = new SparkContext(sconf) + val expand_size = 500 + val data = sc.parallelize(Seq(1,2)). + flatMap(x => Stream.range(1, expand_size). 
+ map(y => "%d: string test %d".format(y,x))) + var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) + assert(persisted.filter(_.startsWith("1:")).count()===2) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 20ebb1897e6ba..30415814adbba 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UISuite extends FunSuite { test("jetty port increases under contention") { val startPort = 4040 @@ -34,15 +36,17 @@ class UISuite extends FunSuite { case Failure(e) => // Either case server port is busy hence setup for test complete } - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) + val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) + val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) } test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq()) + val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf) assert(jettyServer.getState === "STARTED") assert(boundPort != 0) Try {new ServerSocket(boundPort)} match { diff --git a/docker/README.md b/docker/README.md index bf59e77d111f9..40ba9c3065946 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,4 +2,6 @@ Spark docker files =========== Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles), -as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). \ No newline at end of file +as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). + +Tested with Docker version 0.8.1. 
diff --git a/docker/spark-test/master/default_cmd b/docker/spark-test/master/default_cmd index a5b1303c2ebdb..5a7da3446f6d2 100755 --- a/docker/spark-test/master/default_cmd +++ b/docker/spark-test/master/default_cmd @@ -19,4 +19,10 @@ IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }') echo "CONTAINER_IP=$IP" -/opt/spark/spark-class org.apache.spark.deploy.master.Master -i $IP +export SPARK_LOCAL_IP=$IP +export SPARK_PUBLIC_DNS=$IP + +# Avoid the default Docker behavior of mapping our IP address to an unreachable host name +umount /etc/hosts + +/opt/spark/bin/spark-class org.apache.spark.deploy.master.Master -i $IP diff --git a/docker/spark-test/worker/default_cmd b/docker/spark-test/worker/default_cmd index ab6336f70c1c6..31b06cb0eb047 100755 --- a/docker/spark-test/worker/default_cmd +++ b/docker/spark-test/worker/default_cmd @@ -19,4 +19,10 @@ IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }') echo "CONTAINER_IP=$IP" -/opt/spark/spark-class org.apache.spark.deploy.worker.Worker $1 +export SPARK_LOCAL_IP=$IP +export SPARK_PUBLIC_DNS=$IP + +# Avoid the default Docker behavior of mapping our IP address to an unreachable host name +umount /etc/hosts + +/opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker $1 diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index a982c4dbac7d4..d3bc34e68b240 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -56,7 +56,7 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o The ScalaTest plugin also supports running only a specific test suite as follows: - $ mvn -Dhadoop.version=... -Dsuites=spark.repl.ReplSuite test + $ mvn -Dhadoop.version=... -Dsuites=org.apache.spark.repl.ReplSuite test ## Continuous Compilation ## diff --git a/docs/configuration.md b/docs/configuration.md index dc5553f3da770..a006224d5080c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -147,6 +147,34 @@ Apart from these, the following properties are also available, and may be useful How many stages the Spark UI remembers before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark web UI. The filter should be a + standard javax servlet Filter. Parameters to each filter can also be specified by setting a + java system property of spark.<class name of filter>.params='param1=value1,param2=value2' + (e.g. -Dspark.ui.filters=com.test.filter1 -Dspark.com.test.filter1.params='param1=foo,param2=testing') + + + + spark.ui.acls.enable + false + + Whether Spark web UI ACLs are enabled. If enabled, this checks to see if the user has + access permissions to view the web UI. See spark.ui.view.acls for more details. + Also note this requires the user to be known; if the user comes across as null, no checks + are done. Filters can be used to authenticate and set the user. + + + + spark.ui.view.acls + Empty + + Comma separated list of users that have view access to the Spark web UI. By default only the + user that started the Spark job has view access. + + spark.shuffle.compress true @@ -244,6 +272,17 @@ Apart from these, the following properties are also available, and may be useful exceeded" exception inside Kryo. Note that there will be one buffer per core on each worker.
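To make the spark.ui.filters contract above concrete, here is a hedged sketch of such a filter. The class name and its 'user' parameter are hypothetical; the only requirement stated by the config entry is that the class implements the standard javax.servlet.Filter interface:

```scala
import javax.servlet.{Filter, FilterChain, FilterConfig, ServletRequest, ServletResponse}

// Hypothetical UI filter for illustration only. Init parameters arrive via the
// spark.<class name of filter>.params system property described above.
class ExampleAuthFilter extends Filter {
  private var expectedUser = "unknown"

  def init(conf: FilterConfig) {
    val u = conf.getInitParameter("user") // hypothetical parameter name
    if (u != null) expectedUser = u
  }

  def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain) {
    // A real filter would authenticate the request here, rejecting it or
    // recording the authenticated user before passing the request along.
    chain.doFilter(req, res)
  }

  def destroy() { }
}
```

It would then be wired in at launch time with, for example, -Dspark.ui.filters=com.example.ExampleAuthFilter -Dspark.com.example.ExampleAuthFilter.params='user=alice' (names illustrative).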
+ + spark.serializer.objectStreamReset + 10000 + + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data; however, that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, allowing old + objects to be collected. To turn off this periodic reset, set it to a value of <= 0. + By default it will reset the serializer every 10,000 objects. + + spark.broadcast.factory org.apache.spark.broadcast.
HttpBroadcastFactory @@ -476,7 +515,7 @@ Apart from these, the following properties are also available, and may be useful the whole cluster by default.
Note: this setting needs to be configured in the standalone cluster master, not in individual applications; you can set it through SPARK_JAVA_OPTS in spark-env.sh. - + spark.files.overwrite @@ -485,6 +524,38 @@ Apart from these, the following properties are also available, and may be useful Whether to overwrite files added through SparkContext.addFile() when the target file exists and its contents do not match those of the source. + + spark.files.fetchTimeout + false + + Communication timeout to use when fetching files added through SparkContext.addFile() from + the driver. + + + + spark.authenticate + false + + Whether Spark authenticates its internal connections. See spark.authenticate.secret if not + running on Yarn. + + + + spark.authenticate.secret + None + + Set the secret key used for Spark to authenticate between components. This needs to be set if + not running on Yarn and authentication is enabled. + + + + spark.core.connection.auth.wait.timeout + 30 + + Number of seconds for the connection to wait for authentication to occur before timing + out and giving up. + + ## Viewing Spark Properties diff --git a/docs/index.md b/docs/index.md index 4eb297df39144..c4f4d79edbc6c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,6 +103,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use +* [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications * [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ee1d892a3b630..b17929542c531 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -29,7 +29,7 @@ If you want to test out the YARN deployment mode, you can use the current Spark # Configuration -Most of the configs are the same for Spark on YARN as other deploys. See the Configuration page for more information on those. These are configs that are specific to SPARK on YARN. +Most of the configs are the same for Spark on YARN as for other deployment modes. See the Configuration page for more information on those. These are configs that are specific to Spark on YARN. Environment variables: @@ -41,28 +41,30 @@ System Properties: * `spark.yarn.submit.file.replication`, the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives. * `spark.yarn.preserve.staging.files`, set to true to preserve the staged files (spark jar, app jar, distributed cache files) at the end of the job rather than delete them. * `spark.yarn.scheduler.heartbeat.interval-ms`, the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds. -* `spark.yarn.max.worker.failures`, the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3. +* `spark.yarn.max.worker.failures`, the maximum number of executor failures before failing the application. Default is the number of executors requested times 2, with a minimum of 3.
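Taken together, the authentication and ACL properties introduced in configuration.md above combine roughly as follows from application code. This is a sketch: the master URL, secret, and user list are placeholders, and on YARN the secret is generated and distributed automatically, so spark.authenticate.secret would be omitted there:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("spark://master:7077")               // placeholder master URL
  .setAppName("secured_app")
  .set("spark.authenticate", "true")              // enable the shared-secret handshake
  .set("spark.authenticate.secret", "change-me")  // required off-YARN; placeholder value
  .set("spark.ui.acls.enable", "true")            // enforce web UI view ACLs
  .set("spark.ui.view.acls", "alice,bob")         // placeholder user list
val sc = new SparkContext(conf)
```

For standalone deployments the same secret has to be configured on the master and workers as well, or the handshake will fail.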
# Launching Spark on YARN -Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. -This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. +Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to connect to the cluster, write to the dfs, and connect to the YARN ResourceManager. -There are two scheduler mode that can be used to launch spark application on YARN. +There are two scheduler modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -## Launch spark application by YARN Client with yarn-standalone mode. +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". -The command to launch the YARN Client is as follows: +## Launching a Spark application with yarn-cluster mode. + +The command to launch the Spark application on the cluster is as follows: SPARK_JAR=<SPARK_ASSEMBLY_JAR_FILE> ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar <YOUR_APP_JAR_FILE> \ --class <APP_MAIN_CLASS> \ --args <APP_MAIN_ARGUMENTS> \ - --num-workers <NUMBER_OF_WORKER_MACHINES> \ + --num-workers <NUMBER_OF_EXECUTORS> \ --master-class <ApplicationMaster_CLASS> --master-memory <MEMORY_FOR_MASTER> \ - --worker-memory <MEMORY_PER_WORKER> \ - --worker-cores <CORES_PER_WORKER> \ + --worker-memory <MEMORY_PER_EXECUTOR> \ + --worker-cores <CORES_PER_EXECUTOR> \ --name <application_name> \ --queue <queue_name> \ --addJars <any_local_files_used_in_SparkContext.addJar> \ @@ -82,35 +84,30 @@ For example: ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ --class org.apache.spark.examples.SparkPi \ - --args yarn-standalone \ + --args yarn-cluster \ --num-workers 3 \ --master-memory 4g \ --worker-memory 2g \ --worker-cores 1 - # Examine the output (replace $YARN_APP_ID in the following with the "application identifier" output by the previous command) - # (Note: YARN_APP_LOGS_DIR is usually /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version.) - $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout - Pi is roughly 3.13794 - -The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of the Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. -With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell.
+Because the application is run on a remote machine where the Application Master is running, applications that involve local interaction, such as spark-shell, will not work. -## Launch spark application with yarn-client mode. +## Launching a Spark application with yarn-client mode. -With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR. +With yarn-client mode, the application will be launched locally, just like running an application or spark-shell in Local / Mesos / Standalone client mode. The launch method is also the same, just make sure to specify the master URL as "yarn-client". You also need to export the env value for SPARK_JAR. Configuration in yarn-client mode: -In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options. +In order to tune worker cores/number/memory etc., you need to export environment variables or add them to the Spark configuration file (./conf/spark-env.sh). The following is the list of options: -* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2) -* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1). -* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) +* `SPARK_WORKER_INSTANCES`, Number of executors to start (Default: 2) +* `SPARK_WORKER_CORES`, Number of cores per executor (Default: 1). +* `SPARK_WORKER_MEMORY`, Memory per executor (e.g. 1000M, 2G) (Default: 1G) * `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 MB) * `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark) -* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default') +* `SPARK_YARN_QUEUE`, The YARN queue to use for allocation requests (Default: 'default') * `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job. * `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job. @@ -125,13 +122,23 @@ or MASTER=yarn-client ./bin/spark-shell +## Viewing logs + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the yarn.log-aggregation-enable config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId <app ID> + +will print out the contents of all log files from all containers for the given application. + +When log aggregation isn't turned on, logs are retained locally on each machine under YARN_APP_LOGS_DIR, which is usually configured to /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + # Building Spark for Hadoop/YARN 2.2.x -See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using the Maven process.
+See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using Maven. -# Important Notes +# Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. -- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN. +- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. +- The --files and --archives options support specifying file names with the # symbol, similar to Hadoop. For example, you can specify --files localtest.txt#appSees.txt: this will upload the file you have locally named localtest.txt into HDFS, but it will be linked to by the name appSees.txt, and your application should use the name appSees.txt to reference it when running on YARN. - The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000000000..9e4218fbcfe7d --- /dev/null +++ b/docs/security.md @@ -0,0 +1,18 @@ +--- +layout: global +title: Spark Security +--- + +Spark currently supports authentication via a shared secret. Authentication can be turned on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. + +The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user, and then, once the user is logged in, Spark can compare that user against the view ACLs to make sure they are authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' control the behavior of the ACLs. Note that the person who started the application always has view access to the UI. + +For Spark on Yarn deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the ACL controls can be used to control which users can view the Spark UI.
+ +For other types of Spark deployments, the Spark config `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all of the masters, workers, and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the ACL controls can be used to control which users can view the Spark UI. + +IMPORTANT NOTE: The NettyBlockFetcherIterator is not secured, so do not use Netty for the shuffle if running with authentication on. + +See [Spark Configuration](configuration.html) for more details on the security configs. + +See org.apache.spark.SecurityManager for implementation details about security. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 2a56cf07d0cfc..f9904d45013f6 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -539,7 +539,7 @@ common ones are as follows. updateStateByKey(func) Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values for the key. This can be - used to maintain arbitrary state data for each ket. + used to maintain arbitrary state data for each key. diff --git a/docs/tuning.md b/docs/tuning.md index 26ff1325bb59c..093df3187a789 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -163,7 +163,7 @@ their work directories), *not* on your driver program. **Cache Size Tuning** One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. -By default, Spark uses 60% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to +By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to cache RDDs. This means that 40% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 25e85381896b0..d8840c94ac17c 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -398,15 +398,13 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): if any((master_nodes, slave_nodes)): print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) - if (master_nodes != [] and slave_nodes != []) or not die_on_error: + if master_nodes != [] or not die_on_error: return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print "ERROR: Could not find master in group " + cluster_name + "-master" - elif master_nodes != [] and slave_nodes == []: - print "ERROR: Could not find slaves in group " + cluster_name + "-slaves" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" else: - print "ERROR: Could not find any existing cluster" + print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) @@ -680,6 +678,9 @@ def real_main(): opts.zone = random.choice(conn.get_all_zones()).name if action == "launch": + if opts.slaves <= 0: + print >> sys.stderr, "ERROR: You have to start at least 1 slave" + sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name) diff --git a/examples/pom.xml b/examples/pom.xml index 3aba343f4cf50..9f0e2d0b875b8 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala new file mode 100644 index 0000000000000..ee283ce6abac2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import java.nio.ByteBuffer +import scala.collection.JavaConversions._ +import scala.collection.mutable.ListBuffer +import scala.collection.immutable.Map +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat +import org.apache.cassandra.hadoop.cql3.CqlConfigHelper +import org.apache.cassandra.hadoop.cql3.CqlOutputFormat +import org.apache.cassandra.utils.ByteBufferUtil +import org.apache.hadoop.mapreduce.Job +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +/* + You need to create the following keyspace and column family in Cassandra before running this example. + Start the CQL shell using ./bin/cqlsh and execute the following commands: + CREATE KEYSPACE retail WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; + use retail; + CREATE TABLE salecount (prod_id text, sale_count int, PRIMARY KEY (prod_id)); + CREATE TABLE ordercf (user_id text, + time timestamp, + prod_id text, + quantity int, + PRIMARY KEY (user_id, time)); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('bob', 1385983646000, 'iphone', 1); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('tom', 1385983647000, 'samsung', 4); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('dora', 1385983648000, 'nokia', 2); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('charlie', 1385983649000, 'iphone', 2); +*/ + +/** + * This example demonstrates how to read from and write to a Cassandra column family created + * using CQL3, with Spark. + * Usage: ./bin/run-example org.apache.spark.examples.CassandraCQLTest local[2] localhost 9160 + */ +object CassandraCQLTest { + + def main(args: Array[String]) { + val sc = new SparkContext(args(0), + "CQLTestApp", + System.getenv("SPARK_HOME"), + SparkContext.jarOfClass(this.getClass)) + val cHost: String = args(1) + val cPort: String = args(2) + val KeySpace = "retail" + val InputColumnFamily = "ordercf" + val OutputColumnFamily = "salecount" + + val job = new Job() + job.setInputFormatClass(classOf[CqlPagingInputFormat]) + ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) + ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + CqlConfigHelper.setInputCQLPageRowSize(job.getConfiguration(), "3") + + /** CqlConfigHelper.setInputWhereClauses(job.getConfiguration(), "user_id='bob'") */ + + /** An UPDATE writes one or more columns to a record in a Cassandra column family */ + val query = "UPDATE " + KeySpace + "." + OutputColumnFamily + " SET sale_count = ? 
" + CqlConfigHelper.setOutputCql(job.getConfiguration(), query) + + job.setOutputFormatClass(classOf[CqlOutputFormat]) + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), KeySpace, OutputColumnFamily) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), cHost) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), cPort) + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + classOf[CqlPagingInputFormat], + classOf[java.util.Map[String,ByteBuffer]], + classOf[java.util.Map[String,ByteBuffer]]) + + println("Count: " + casRdd.count) + val productSaleRDD = casRdd.map { + case (key, value) => { + (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) + } + } + val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) + aggregatedRDD.collect().foreach { + case (productId, saleCount) => println(productId + ":" + saleCount) + } + + val casoutputCF = aggregatedRDD.map { + case (productId, saleCount) => { + val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) + val outKey: java.util.Map[String, ByteBuffer] = outColFamKey + var outColFamVal = new ListBuffer[ByteBuffer] + outColFamVal += ByteBufferUtil.bytes(saleCount) + val outVal: java.util.List[ByteBuffer] = outColFamVal + (outKey, outVal) + } + } + + casoutputCF.saveAsNewAPIHadoopFile( + KeySpace, + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.List[ByteBuffer]], + classOf[CqlOutputFormat], + job.getConfiguration() + ) + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index 3d7b390724e77..62d3a52615584 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -23,7 +23,7 @@ import scala.util.Random import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions import org.apache.spark.streaming.receivers.Receiver @@ -112,8 +112,9 @@ object FeederActor { } val Seq(host, port) = args.toSeq - - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf)._1 + val conf = new SparkConf + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf, + securityManager = new SecurityManager(conf))._1 val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") println("Feeder started as:" + feeder) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8783aea3e4a5b..f21963531574b 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 79dc38f9844a0..343e1fabd823f 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 37bb4fad64f68..398b9f4fbaa7d 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> 
yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 65ec0e26da881..77e957f404645 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/graphx/pom.xml b/graphx/pom.xml index 5b54dd27efb44..894a7c2641e39 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d1528e2f07cf2..014a7335f85cc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -23,8 +23,8 @@ import scala.collection.mutable.HashSet import org.apache.spark.util.Utils -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import org.objectweb.asm.Opcodes._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ /** diff --git a/mllib/pom.xml b/mllib/pom.xml index 760a2a85d5ffa..9b65cb4b4ce3f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/pom.xml b/pom.xml index c59fada5cd4a0..f0c877dcfe7b2 100644 --- a/pom.xml +++ b/pom.xml @@ -155,6 +155,21 @@ + + org.eclipse.jetty + jetty-util + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-security + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-plus + 7.6.8.v20121106 + org.eclipse.jetty jetty-server @@ -206,11 +221,6 @@ snappy-java 1.0.5 - - org.ow2.asm - asm - 4.0 - com.clearspring.analytics stream @@ -230,11 +240,31 @@ com.twitter chill_${scala.binary.version} 0.3.1 + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + com.twitter chill-java 0.3.1 + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + ${akka.group} @@ -295,6 +325,11 @@ mesos ${mesos.version} + + commons-net + commons-net + 2.2 + io.netty netty-all @@ -415,6 +450,10 @@ asm asm + + org.ow2.asm + asm + org.jboss.netty netty @@ -454,6 +493,10 @@ asm asm + + org.ow2.asm + asm + org.jboss.netty netty @@ -469,6 +512,10 @@ asm asm + + org.ow2.asm + asm + org.jboss.netty netty @@ -485,6 +532,10 @@ asm asm + + org.ow2.asm + asm + org.jboss.netty netty diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index aa1784897566b..8fa220c413291 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -226,6 +226,9 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "io.netty" % "netty-all" % "4.0.17.Final", "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. 
*/ "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", @@ -254,7 +257,8 @@ object SparkBuild extends Build { val slf4jVersion = "1.7.5" val excludeNetty = ExclusionRule(organization = "org.jboss.netty") - val excludeAsm = ExclusionRule(organization = "asm") + val excludeAsm = ExclusionRule(organization = "org.ow2.asm") + val excludeOldAsm = ExclusionRule(organization = "asm") val excludeCommonsLogging = ExclusionRule(organization = "commons-logging") val excludeSLF4J = ExclusionRule(organization = "org.slf4j") val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap") @@ -277,7 +281,6 @@ object SparkBuild extends Build { "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 "com.ning" % "compress-lzf" % "1.0.0", "org.xerial.snappy" % "snappy-java" % "1.0.5", - "org.ow2.asm" % "asm" % "4.0", "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), "org.spark-project.akka" %% "akka-testkit" % "2.2.3-shaded-protobuf" % "test", @@ -285,17 +288,18 @@ object SparkBuild extends Build { "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", "org.apache.mesos" % "mesos" % "0.13.0", + "commons-net" % "commons-net" % "2.2", "net.java.dev.jets3t" % "jets3t" % "0.7.1" excludeAll(excludeCommonsLogging), "org.apache.derby" % "derby" % "10.4.2.0" % "test", - "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J), + "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm), "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeNetty), "com.codahale.metrics" % "metrics-core" % "3.0.0", "com.codahale.metrics" % "metrics-jvm" % "3.0.0", "com.codahale.metrics" % "metrics-json" % "3.0.0", "com.codahale.metrics" % "metrics-ganglia" % "3.0.0", "com.codahale.metrics" % "metrics-graphite" % "3.0.0", - "com.twitter" %% "chill" % "0.3.1", - "com.twitter" % "chill-java" % "0.3.1", + "com.twitter" %% "chill" % "0.3.1" excludeAll(excludeAsm), + "com.twitter" % "chill-java" % "0.3.1" excludeAll(excludeAsm), "com.clearspring.analytics" % "stream" % "2.5.1" ), libraryDependencies ++= maybeAvro @@ -316,7 +320,7 @@ object SparkBuild extends Build { name := "spark-examples", libraryDependencies ++= Seq( "com.twitter" %% "algebird-core" % "0.1.11", - "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging), + "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging), "org.apache.cassandra" % "cassandra-all" % "1.2.6" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") @@ -393,10 +397,10 @@ object SparkBuild extends Build { def yarnEnabledSettings = Seq( libraryDependencies ++= Seq( // Exclude rule required for all ? 
- "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm), - "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeNetty, excludeAsm), - "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeNetty, excludeAsm), - "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeNetty, excludeAsm) + "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm) ) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index 914f2e05a402a..32bc044a93221 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,3 +19,4 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.0") diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 93faa2e3857ed..c9f42d3aacb58 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -372,6 +372,37 @@ def _getJavaStorageLevel(self, storageLevel): return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory, storageLevel.deserialized, storageLevel.replication) + def setJobGroup(self, groupId, description): + """ + Assigns a group ID to all the jobs started by this thread until the group ID is set to a + different value or cleared. + + Often, a unit of execution in an application consists of multiple Spark actions or jobs. + Application programmers can use this method to group all those jobs together and give a + group description. Once set, the Spark web UI will associate such jobs with this group. + """ + self._jsc.setJobGroup(groupId, description) + + def setLocalProperty(self, key, value): + """ + Set a local property that affects jobs submitted from this thread, such as the + Spark fair scheduler pool. + """ + self._jsc.setLocalProperty(key, value) + + def getLocalProperty(self, key): + """ + Get a local property set in this thread, or null if it is missing. See + L{setLocalProperty} + """ + return self._jsc.getLocalProperty(key) + + def sparkUser(self): + """ + Get SPARK_USER for user who is running SparkContext. 
+ """ + return self._jsc.sc().sparkUser() + def _test(): import atexit import doctest diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c15add5237507..6a16756e0576d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,7 +29,7 @@ def launch_gateway(): # Launch the Py4j gateway using Spark's run command so that we pick up the - # proper classpath and SPARK_MEM settings from spark-env.sh + # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-class.cmd" if on_windows else "./bin/spark-class" command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index be23f87f5ed2d..e72f57d9d1ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -95,6 +95,13 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self.is_checkpointed = False self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer + self._id = jrdd.id() + + def id(self): + """ + A unique ID for this RDD (within its SparkContext). + """ + return self._id def __repr__(self): return self._jrdd.toString() @@ -319,6 +326,23 @@ def union(self, other): return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, self.ctx.serializer) + def intersection(self, other): + """ + Return the intersection of this RDD and another one. The output will not + contain any duplicate elements, even if the input RDDs did. + + Note that this method performs a shuffle internally. + + >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) + >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8]) + >>> rdd1.intersection(rdd2).collect() + [1, 2, 3] + """ + return self.map(lambda v: (v, None)) \ + .cogroup(other.map(lambda v: (v, None))) \ + .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ + .keys() + def _reserialize(self): if self._jrdd_deserializer == self.ctx.serializer: return self diff --git a/repl/pom.xml b/repl/pom.xml index aa01a1760285a..fc49c8b811316 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -37,10 +37,10 @@ yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index e3bcf7f30ac8d..ee972887feda6 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -18,14 +18,18 @@ package org.apache.spark.repl import java.io.{ByteArrayOutputStream, InputStream} -import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.net.{URI, URL, URLEncoder} import java.util.concurrent.{Executors, ExecutorService} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.objectweb.asm._ -import org.objectweb.asm.Opcodes._ +import org.apache.spark.SparkEnv +import org.apache.spark.util.Utils + + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ /** @@ -53,7 +57,13 @@ extends ClassLoader(parent) { if (fileSystem != null) { fileSystem.open(new Path(directory, pathInDirectory)) } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + 
newuri.toURL().openStream() + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } } } val bytes = readAndTransformClass(name, inputStream) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index f52ebe4a159f1..9b1da195002c2 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -881,6 +881,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, }) def process(settings: Settings): Boolean = savingContextLoader { + if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + this.settings = settings createInterpreter() @@ -939,16 +941,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val master = this.master match { - case Some(m) => m - case None => { - val prop = System.getenv("MASTER") - if (prop != null) prop else "local" - } - } val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) val conf = new SparkConf() - .setMaster(master) + .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServer.uri) @@ -963,6 +958,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, sparkContext } + private def getMaster(): String = { + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + master + } + /** process command-line arguments and do as they request */ def process(args: Array[String]): Boolean = { val command = new SparkCommandLine(args.toList, msg => echo(msg)) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1d73d0b6993a8..90a96ad38381e 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -36,7 +36,7 @@ import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable import util.stackTraceString -import org.apache.spark.{HttpServer, SparkConf, Logging} +import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils // /** directory to save .class files to */ @@ -83,15 +83,17 @@ import org.apache.spark.util.Utils * @author Moez A. 
Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging { + class SparkIMain(initialSettings: Settings, val out: JPrintWriter) + extends SparkImports with Logging { imain => - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + val conf = new SparkConf() + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ val outputDir = { val tmp = System.getProperty("java.io.tmpdir") - val rootDir = new SparkConf().get("spark.repl.classdir", tmp) + val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) } if (SPARK_DEBUG_REPL) { @@ -99,7 +101,8 @@ import org.apache.spark.util.Utils } val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir, + new SecurityManager(conf)) /** Jetty server that will serve our classes to worker nodes */ private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash index 00a6b41013e5f..64e40a88206be 100755 --- a/sbt/sbt-launch-lib.bash +++ b/sbt/sbt-launch-lib.bash @@ -105,7 +105,7 @@ get_mem_opts () { local mem=${1:-2048} local perm=$(( $mem / 4 )) (( $perm > 256 )) || perm=256 - (( $perm < 1024 )) || perm=1024 + (( $perm < 4096 )) || perm=4096 local codecache=$(( $perm / 2 )) echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m" diff --git a/streaming/pom.xml b/streaming/pom.xml index 91d6a1375a18c..2343e381e6f7c 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -37,10 +37,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/tools/pom.xml b/tools/pom.xml index b8dd255d40ac4..11433e596f5b0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -36,10 +36,10 @@ yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index bfe12ecec0c09..d0aeaceb0d23c 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -30,10 +30,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e045b9f0248f6..bb574f415293a 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -36,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} 
import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -87,27 +86,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts resourceManager = registerWithResourceManager() - // Workaround until hadoop moves to something which has - // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // ignore result. - // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times - // Hence args.workerCores = numCore disabled above. Any better option? - - // Compute number of threads for akka - //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - //if (minimumMemory > 0) { - // val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD - // val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - // if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore - // } - //} - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) + + // Call this to force generation of secret so it gets populated into the + // hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the worker containers. + val securityMgr = new SecurityManager(sparkConf) + // Start the user's JAR userThread = startUserClass() @@ -132,6 +120,20 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", + params) + } + /** Get the Yarn approved local directories. 
*/ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 138c27910b0b0..b735d01df8097 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo @@ -50,8 +50,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var yarnAllocator: YarnAllocationHandler = _ private var driverClosed:Boolean = false + val securityManager = new SecurityManager(sparkConf) val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor just working as a monitor to watch on Driver Actor. @@ -110,6 +111,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index fe37168e5a7ba..1f894a677d169 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -129,12 +129,12 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { System.err.println( "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + - " --jar JAR_PATH Path to your application's JAR file (required in yarn-standalone mode)\n" + + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1).\n" + " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + " --worker-memory MEM Memory per Worker (e.g. 
1000M, 2G) (Default: 1G)\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index d6c12a9f5952d..4c6e1dcd6dac3 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,11 +17,13 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.SparkHadoopUtil /** * Contains util methods to interact with Hadoop from spark. @@ -44,4 +46,24 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val jobCreds = conf.getCredentials() jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } + + override def getCurrentUserCredentials(): Credentials = { + UserGroupInformation.getCurrentUser().getCredentials() + } + + override def addCurrentUserCredentials(creds: Credentials) { + UserGroupInformation.getCurrentUser().addCredentials(creds) + } + + override def addSecretKeyToUserCredentials(key: String, secret: String) { + val creds = new Credentials() + creds.addSecretKey(new Text(key), secret.getBytes()) + addCurrentUserCredentials(creds) + } + + override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { + val credentials = getCurrentUserCredentials() + if (credentials != null) credentials.getSecretKey(new Text(key)) else null + } + } diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 9d68603251d1c..e7915d12aef63 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -30,10 +30,10 @@ a Hadoop 0.23.X issue --> yarn-alpha - - org.apache.avro - avro - + + org.apache.avro + avro + diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index dd117d5810949..b48a2d50db5ef 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ @@ -37,8 +36,9 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.yarn.webapp.util.WebAppUtils; -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -91,12 +91,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, amClient.init(yarnConf) amClient.start() - // Workaround until hadoop moves to something which has - // 
https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) + // Call this to force generation of secret so it gets populated into the + // hadoop UGI. This has to happen before startUserClass, which does a + // doAs, in order for the credentials to be passed on to the worker containers. + val securityMgr = new SecurityManager(sparkConf) + // Start the user's JAR userThread = startUserClass() @@ -121,6 +125,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) + } + /** Get the Yarn approved local directories. */ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -261,7 +278,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val schedulerInterval = sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 40600f38e5e73..f1c1fea0b5895 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo @@ -52,8 +52,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var amClient: AMRMClient[ContainerRequest] = _ + val securityManager = new SecurityManager(sparkConf) val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor works as a monitor to watch the Driver Actor. @@ -105,6 +106,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val interval = math.min(timeoutInterval / 2, schedulerInterval) reporterThread = launchReporterThread(interval) + // Wait for the reporter thread to finish. reporterThread.join()
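The recurring change across the launchers above is that every actor system is now created with an explicit SecurityManager, so the shared-secret handshake also covers Akka traffic. A sketch of that internal pattern follows; AkkaUtils is Spark-internal API and the host/port values are placeholders, so this is illustrative rather than something applications would call directly:

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.AkkaUtils

val sparkConf = new SparkConf()
val securityManager = new SecurityManager(sparkConf)
// createActorSystem returns the ActorSystem together with the port it bound to.
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
  "sparkYarnAM", "localhost", 0, conf = sparkConf, securityManager = securityManager)
```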