diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkEnvUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkEnvUtils.scala
index 01e3f52de41f3..2a82bbbebeb2a 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkEnvUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkEnvUtils.scala
@@ -25,6 +25,10 @@ private[spark] trait SparkEnvUtils {
    */
   def isTesting: Boolean = JavaUtils.isTesting
 
+  /**
+   * Whether to allow using native BLAS/LAPACK/ARPACK libraries if available.
+   */
+  val allowNativeBlas = "true".equals(System.getProperty("netlib.allowNativeBlas", "true"))
 }
 
 object SparkEnvUtils extends SparkEnvUtils
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b0ac6d96a0010..e57d304685efc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -432,6 +432,7 @@ class SparkContext(config: SparkConf) extends Logging {
 
     SparkContext.supplementJavaModuleOptions(_conf)
     SparkContext.supplementJavaIPv6Options(_conf)
+    SparkContext.supplementBlasOptions(_conf)
 
     _driverLogger = DriverLogger(_conf)
 
@@ -3414,26 +3415,30 @@ object SparkContext extends Logging {
     }
   }
 
+  private def supplementJavaOpts(conf: SparkConf, key: String, javaOpts: String): Unit = {
+    conf.set(key, s"$javaOpts ${conf.get(key, "")}".trim())
+  }
+
   /**
    * SPARK-36796: This is a helper function to supplement some JVM runtime options to
    * `spark.driver.extraJavaOptions` and `spark.executor.extraJavaOptions`.
    */
   private def supplementJavaModuleOptions(conf: SparkConf): Unit = {
-    def supplement(key: String): Unit = {
-      val v = s"${JavaModuleOptions.defaultModuleOptions()} ${conf.get(key, "")}".trim()
-      conf.set(key, v)
-    }
-    supplement(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)
-    supplement(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS)
+    val opts = JavaModuleOptions.defaultModuleOptions()
+    supplementJavaOpts(conf, SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, opts)
+    supplementJavaOpts(conf, SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS, opts)
   }
 
   private def supplementJavaIPv6Options(conf: SparkConf): Unit = {
-    def supplement(key: String): Unit = {
-      val v = s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6} ${conf.get(key, "")}".trim()
-      conf.set(key, v)
-    }
-    supplement(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)
-    supplement(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS)
+    val opts = s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}"
+    supplementJavaOpts(conf, SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, opts)
+    supplementJavaOpts(conf, SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS, opts)
+  }
+
+  private def supplementBlasOptions(conf: SparkConf): Unit = {
+    val opts = s"-Dnetlib.allowNativeBlas=${Utils.allowNativeBlas}"
+    supplementJavaOpts(conf, SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, opts)
+    supplementJavaOpts(conf, SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS, opts)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index d0f4806c49482..21ea53f7a721b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2912,4 +2912,12 @@ package object config {
       .checkValue(v => v.forall(Set("stdout", "stderr").contains),
         "The value only can be one or more of 'stdout, stderr'.")
       .createWithDefault(Seq("stdout", "stderr"))
+
+  private[spark] val SPARK_ML_ALLOW_NATIVE_BLAS =
+    ConfigBuilder("spark.ml.allowNativeBlas")
+      .doc("Whether to allow using native BLAS/LAPACK/ARPACK implementations when native " +
+        "libraries are available. If disabled, the Java implementations are always used.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git a/docs/ml-linalg-guide.md b/docs/ml-linalg-guide.md
index 6e91d81f49760..aa1471f0df995 100644
--- a/docs/ml-linalg-guide.md
+++ b/docs/ml-linalg-guide.md
@@ -46,8 +46,7 @@ The installation should be done on all nodes of the cluster. Generic version of
 
 For Debian / Ubuntu:
 ```
-sudo apt-get install libopenblas-base
-sudo update-alternatives --config libblas.so.3
+sudo apt-get install libopenblas-dev
 ```
 For CentOS / RHEL:
 ```
@@ -76,6 +75,8 @@ You can also point `dev.ludovic.netlib` to specific libraries names and paths. F
 
 If native libraries are not properly configured in the system, the Java implementation (javaBLAS) will be used as fallback option.
 
+You can also set the Spark configuration `spark.ml.allowNativeBlas` or the Java system property `netlib.allowNativeBlas` to `false` to disable native BLAS and always use the Java implementation.
+
 ## Spark Configuration
 
 The default behavior of multi-threading in either Intel MKL or OpenBLAS may not be optimal with Spark's execution model [^1].
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 7b9f90ac7b7a6..477eb470c0577 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -352,6 +352,11 @@ private List<String> buildSparkSubmitCommand(Map<String, String> env)
         config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));
     }
 
+    if (config.containsKey("spark.ml.allowNativeBlas")) {
+      String allowNativeBlas = config.get("spark.ml.allowNativeBlas");
+      addOptionString(cmd, "-Dnetlib.allowNativeBlas=" + allowNativeBlas);
+    }
+
     // SPARK-36796: Always add some JVM runtime default options to submit command
     addOptionString(cmd, JavaModuleOptions.defaultModuleOptions());
     addOptionString(cmd, "-Dderby.connection.requireAuthentication=false");
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index d07eb890dc325..e9b8a6656bfc9 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -39,8 +39,12 @@ private[spark] object BLAS extends Serializable {
   // For level-3 routines, we use the native BLAS.
   private[spark] def nativeBLAS: NetlibBLAS = {
     if (_nativeBLAS == null) {
-      _nativeBLAS =
+      // Replicate SparkEnvUtils.allowNativeBlas to avoid pulling in common/utils as a dependency.
+      _nativeBLAS = if ("true".equals(System.getProperty("netlib.allowNativeBlas", "true"))) {
         try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS }
+      } else {
+        javaBLAS
+      }
     }
     _nativeBLAS
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
index 23e514b3a2677..5f109754de23d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala
@@ -19,10 +19,13 @@ package org.apache.spark.mllib.linalg
 
 import dev.ludovic.netlib.arpack.{ARPACK => NetlibARPACK, JavaARPACK => NetlibJavaARPACK, NativeARPACK => NetlibNativeARPACK}
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.SparkEnvUtils
+
 /**
  * ARPACK routines for MLlib's vectors and matrices.
  */
-private[spark] object ARPACK extends Serializable {
+private[spark] object ARPACK extends Serializable with Logging {
 
   @transient private var _javaARPACK: NetlibARPACK = _
   @transient private var _nativeARPACK: NetlibARPACK = _
@@ -36,8 +39,12 @@ private[spark] object ARPACK extends Serializable {
 
   private[spark] def nativeARPACK: NetlibARPACK = {
     if (_nativeARPACK == null) {
-      _nativeARPACK =
+      _nativeARPACK = if (SparkEnvUtils.allowNativeBlas) {
         try { NetlibNativeARPACK.getInstance } catch { case _: Throwable => javaARPACK }
+      } else {
+        logInfo("Disabling native ARPACK because netlib.allowNativeBlas is false.")
+        javaARPACK
+      }
     }
     _nativeARPACK
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 637380752c1db..753d80e5a69b9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -17,44 +17,14 @@
 
 package org.apache.spark.mllib.linalg
 
-import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS => NetlibJavaBLAS, NativeBLAS => NetlibNativeBLAS}
-
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.BLAS.{getBLAS, nativeBLAS}
 
 /**
  * BLAS routines for MLlib's vectors and matrices.
  */
 private[spark] object BLAS extends Serializable with Logging {
 
-  @transient private var _javaBLAS: NetlibBLAS = _
-  @transient private var _nativeBLAS: NetlibBLAS = _
-  private val nativeL1Threshold: Int = 256
-
-  // For level-1 function dspmv, use javaBLAS for better performance.
-  private[spark] def javaBLAS: NetlibBLAS = {
-    if (_javaBLAS == null) {
-      _javaBLAS = NetlibJavaBLAS.getInstance
-    }
-    _javaBLAS
-  }
-
-  // For level-3 routines, we use the native BLAS.
-  private[spark] def nativeBLAS: NetlibBLAS = {
-    if (_nativeBLAS == null) {
-      _nativeBLAS =
-        try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS }
-    }
-    _nativeBLAS
-  }
-
-  private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = {
-    if (vectorSize < nativeL1Threshold) {
-      javaBLAS
-    } else {
-      nativeBLAS
-    }
-  }
-
   /**
    * y += a * x
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
index 9ce60c3b396bc..074df1dc3f9bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala
@@ -19,10 +19,13 @@ package org.apache.spark.mllib.linalg
 
 import dev.ludovic.netlib.lapack.{JavaLAPACK => NetlibJavaLAPACK, LAPACK => NetlibLAPACK, NativeLAPACK => NetlibNativeLAPACK}
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.SparkEnvUtils
+
 /**
  * LAPACK routines for MLlib's vectors and matrices.
  */
-private[spark] object LAPACK extends Serializable {
+private[spark] object LAPACK extends Serializable with Logging {
 
   @transient private var _javaLAPACK: NetlibLAPACK = _
   @transient private var _nativeLAPACK: NetlibLAPACK = _
@@ -36,8 +39,12 @@ private[spark] object LAPACK extends Serializable {
 
   private[spark] def nativeLAPACK: NetlibLAPACK = {
     if (_nativeLAPACK == null) {
-      _nativeLAPACK =
+      _nativeLAPACK = if (SparkEnvUtils.allowNativeBlas) {
         try { NetlibNativeLAPACK.getInstance } catch { case _: Throwable => javaLAPACK }
+      } else {
+        logInfo("Disabling native LAPACK because netlib.allowNativeBlas is false.")
+        javaLAPACK
+      }
     }
     _nativeLAPACK
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index a0b31d3d9282b..fecacc8d3b000 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -427,7 +427,7 @@ class DenseMatrix @Since("1.3.0") (
     if (isTransposed) {
       Iterator.tabulate(numCols) { j =>
         val col = new Array[Double](numRows)
-        BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
+        newlinalg.BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
         new DenseVector(col)
       }
     } else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9d239b5b7a503..d8afaacc0a6dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -31,7 +31,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
 import org.apache.spark.internal.{Logging, LogKeys}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.ml.linalg.BLAS
 import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index aadf535e137bd..1eb6c9dad1564 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -23,12 +23,6 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class BLASSuite extends SparkFunSuite {
 
-  test("nativeL1Threshold") {
-    assert(getBLAS(128) == BLAS.javaBLAS)
-    assert(getBLAS(256) == BLAS.nativeBLAS)
-    assert(getBLAS(512) == BLAS.nativeBLAS)
-  }
-
   test("copy") {
     val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
     val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 09655007acb34..92eea27a9fd3f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1049,6 +1049,8 @@ private[spark] class Client(
 
     javaOpts += s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}"
 
+    javaOpts += s"-Dnetlib.allowNativeBlas=${sparkConf.get(SPARK_ML_ALLOW_NATIVE_BLAS)}"
+
     // SPARK-37106: To start AM with Java 17, `JavaModuleOptions.defaultModuleOptions`
    // is added by default.
     javaOpts += JavaModuleOptions.defaultModuleOptions()
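
Reviewer note: taken together, the changes above apply one pattern in three places (BLAS, LAPACK, ARPACK): read the `netlib.allowNativeBlas` JVM property once, then lazily pick an implementation, preferring the native one only when it is both allowed and loadable. Below is a condensed, self-contained sketch of that pattern, assuming only the `dev.ludovic.netlib` classes already used in this diff; the object name `BlasSelection` is ours, for illustration only.

```
import dev.ludovic.netlib.blas.{BLAS => NetlibBLAS, JavaBLAS, NativeBLAS}

object BlasSelection {
  // Read once per JVM, like SparkEnvUtils.allowNativeBlas: native BLAS is
  // allowed unless -Dnetlib.allowNativeBlas=false was passed to this JVM.
  private val allowNativeBlas =
    "true".equals(System.getProperty("netlib.allowNativeBlas", "true"))

  private var _nativeBLAS: NetlibBLAS = _

  // Resolve lazily: prefer native when allowed and loadable; otherwise
  // (explicit opt-out, or native libraries missing) fall back to pure Java.
  def nativeBLAS: NetlibBLAS = {
    if (_nativeBLAS == null) {
      _nativeBLAS = if (allowNativeBlas) {
        try { NativeBLAS.getInstance } catch { case _: Throwable => JavaBLAS.getInstance }
      } else {
        JavaBLAS.getInstance
      }
    }
    _nativeBLAS
  }
}
```

The rest of the diff is plumbing for that property: `spark.ml.allowNativeBlas` set via `--conf` is translated into `-Dnetlib.allowNativeBlas=...` for the submit JVM (launcher), for the driver and executors (`SparkContext.supplementBlasOptions`), and for the YARN AM (`Client`), so every JVM in the application resolves the same implementation.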