[SPARK-5814][MLLIB][GRAPHX] Remove JBLAS from runtime
The issue is discussed in https://issues.apache.org/jira/browse/SPARK-5669. Replacing all JBLAS usage with netlib-java gives us a simpler dependency tree and fewer license issues to worry about. I didn't touch the test scope in this PR. The user guide is not modified, to avoid merge conflicts with branch-1.3.

srowen ankurdave pwendell

Author: Xiangrui Meng <[email protected]>

Closes #4699 from mengxr/SPARK-5814 and squashes the following commits:

48635c6 [Xiangrui Meng] move netlib-java version to parent pom
ca21c74 [Xiangrui Meng] remove jblas from ml-guide
5f7767a [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5814
c5c4183 [Xiangrui Meng] merge master
0f20cad [Xiangrui Meng] add mima excludes
e53e9f4 [Xiangrui Meng] remove jblas from mllib runtime
ceaa14d [Xiangrui Meng] replace jblas by netlib-java in graphx
fa7c2ca [Xiangrui Meng] move jblas to test scope
Commit 0cba802 (1 parent: 712679a)
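The pattern applied throughout the diff below: each jblas `DoubleMatrix` column vector becomes a plain `Array[Double]`, and each matrix-method call maps onto a level-1 BLAS routine from netlib-java. A minimal sketch of that translation (the vectors and `alpha` here are illustrative values, not taken from the patch):

```scala
import com.github.fommil.netlib.BLAS.{getInstance => blas}

val x = Array(1.0, 2.0, 3.0)
val y = Array(4.0, 5.0, 6.0)
val alpha = 0.5

// jblas: x.dot(y)  ->  netlib-java: ddot(n, x, incx, y, incy)
val dot = blas.ddot(x.length, x, 1, y, 1)

// jblas: y.addColumnVector(x.mul(alpha))  ->  daxpy computes
// y := alpha * x + y in place, so clone first to keep the input intact.
val out = y.clone()
blas.daxpy(x.length, alpha, x, 1, out, 1)
```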

File tree: 16 files changed, +183 −144 lines changed


assembly/pom.xml

Lines changed: 0 additions & 10 deletions

```diff
@@ -114,16 +114,6 @@
             <exclude>META-INF/*.RSA</exclude>
           </excludes>
         </filter>
-        <filter>
-          <!-- Exclude libgfortran, libgcc for license issues -->
-          <artifact>org.jblas:jblas</artifact>
-          <excludes>
-            <!-- Linux amd64 is OK; not statically linked -->
-            <exclude>lib/static/Linux/i386/**</exclude>
-            <exclude>lib/static/Mac OS X/**</exclude>
-            <exclude>lib/static/Windows/**</exclude>
-          </excludes>
-        </filter>
       </filters>
     </configuration>
     <executions>
```

docs/mllib-guide.md

Lines changed: 0 additions & 5 deletions

```diff
@@ -80,11 +80,6 @@ include `netlib-java`'s native proxies by default. To configure
 [netlib-java](https://github.com/fommil/netlib-java) documentation for
 your platform's additional installation instructions.
 
-MLlib also uses [jblas](https://github.com/mikiobraun/jblas) which
-will require you to install the
-[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries)
-if it is not already present on your nodes.
-
 To use MLlib in Python, you will need [NumPy](http://www.numpy.org)
 version 1.4 or newer.
 
```

graphx/pom.xml

Lines changed: 8 additions & 3 deletions

```diff
@@ -45,9 +45,14 @@
       <artifactId>guava</artifactId>
     </dependency>
     <dependency>
-      <groupId>org.jblas</groupId>
-      <artifactId>jblas</artifactId>
-      <version>${jblas.version}</version>
+      <groupId>com.github.fommil.netlib</groupId>
+      <artifactId>core</artifactId>
+      <version>${netlib.java.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>net.sourceforge.f2j</groupId>
+      <artifactId>arpack_combined_all</artifactId>
+      <version>0.1</version>
     </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
```
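With this change GraphX depends only on netlib-java's pure-Java `core` API plus the f2j reference implementation bundled in `arpack_combined_all`, so no native libraries are required at runtime. A quick diagnostic sketch (not part of the patch) to confirm which BLAS backend actually gets loaded:

```scala
import com.github.fommil.netlib.BLAS

// Prints "com.github.fommil.netlib.F2jBLAS" when the pure-Java fallback is
// in use, or a native wrapper class when one is available on the machine.
println(BLAS.getInstance().getClass.getName)
```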

graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala

Lines changed: 59 additions & 37 deletions

```diff
@@ -18,7 +18,9 @@
 package org.apache.spark.graphx.lib
 
 import scala.util.Random
-import org.jblas.DoubleMatrix
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
 import org.apache.spark.rdd._
 import org.apache.spark.graphx._
 
@@ -53,7 +55,7 @@ object SVDPlusPlus {
    * a Multifaceted Collaborative Filtering Model",
    * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
    *
-   * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)),
+   * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^^-0.5^^*sum(y)),
    * see the details on page 6.
    *
    * @param edges edges for constructing the graph
@@ -66,13 +68,10 @@ object SVDPlusPlus {
     : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
   {
     // Generate default vertex attribute
-    def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = {
-      val v1 = new DoubleMatrix(rank)
-      val v2 = new DoubleMatrix(rank)
-      for (i <- 0 until rank) {
-        v1.put(i, Random.nextDouble())
-        v2.put(i, Random.nextDouble())
-      }
+    def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {
+      // TODO: use a fixed random seed
+      val v1 = Array.fill(rank)(Random.nextDouble())
+      val v2 = Array.fill(rank)(Random.nextDouble())
       (v1, v2, 0.0, 0.0)
     }
 
@@ -92,7 +91,7 @@ object SVDPlusPlus {
       (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
 
     val gJoinT0 = g.outerJoinVertices(t0) {
-      (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
+      (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),
        msg: Option[(Long, Double)]) =>
         (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
     }.cache()
@@ -102,39 +101,52 @@ object SVDPlusPlus {
 
     def sendMsgTrainF(conf: Conf, u: Double)
         (ctx: EdgeContext[
-          (DoubleMatrix, DoubleMatrix, Double, Double),
+          (Array[Double], Array[Double], Double, Double),
           Double,
-          (DoubleMatrix, DoubleMatrix, Double)]) {
+          (Array[Double], Array[Double], Double)]) {
       val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
       val (p, q) = (usr._1, itm._1)
-      var pred = u + usr._3 + itm._3 + q.dot(usr._2)
+      val rank = p.length
+      var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1)
       pred = math.max(pred, conf.minVal)
       pred = math.min(pred, conf.maxVal)
       val err = ctx.attr - pred
-      val updateP = q.mul(err)
-        .subColumnVector(p.mul(conf.gamma7))
-        .mul(conf.gamma2)
-      val updateQ = usr._2.mul(err)
-        .subColumnVector(q.mul(conf.gamma7))
-        .mul(conf.gamma2)
-      val updateY = q.mul(err * usr._4)
-        .subColumnVector(itm._2.mul(conf.gamma7))
-        .mul(conf.gamma2)
+      // updateP = (err * q - conf.gamma7 * p) * conf.gamma2
+      val updateP = q.clone()
+      blas.dscal(rank, err * conf.gamma2, updateP, 1)
+      blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)
+      // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2
+      val updateQ = usr._2.clone()
+      blas.dscal(rank, err * conf.gamma2, updateQ, 1)
+      blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)
+      // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
+      val updateY = q.clone()
+      blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)
+      blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)
       ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
       ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
     }
 
     for (i <- 0 until conf.maxIters) {
       // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
       g.cache()
-      val t1 = g.aggregateMessages[DoubleMatrix](
+      val t1 = g.aggregateMessages[Array[Double]](
         ctx => ctx.sendToSrc(ctx.dstAttr._2),
-        (g1, g2) => g1.addColumnVector(g2))
+        (g1, g2) => {
+          val out = g1.clone()
+          blas.daxpy(out.length, 1.0, g2, 1, out, 1)
+          out
+        })
       val gJoinT1 = g.outerJoinVertices(t1) {
-        (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
-         msg: Option[DoubleMatrix]) =>
-          if (msg.isDefined) (vd._1, vd._1
-            .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
+        (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),
+         msg: Option[Array[Double]]) =>
+          if (msg.isDefined) {
+            val out = vd._1.clone()
+            blas.daxpy(out.length, vd._4, msg.get, 1, out, 1)
+            (vd._1, out, vd._3, vd._4)
+          } else {
+            vd
+          }
       }.cache()
       materialize(gJoinT1)
       g.unpersist()
@@ -144,14 +156,24 @@ object SVDPlusPlus {
       g.cache()
       val t2 = g.aggregateMessages(
         sendMsgTrainF(conf, u),
-        (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
-          (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
+        (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) =>
+        {
+          val out1 = g1._1.clone()
+          blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)
+          val out2 = g2._2.clone()
+          blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)
+          (out1, out2, g1._3 + g2._3)
+        })
       val gJoinT2 = g.outerJoinVertices(t2) {
         (vid: VertexId,
-         vd: (DoubleMatrix, DoubleMatrix, Double, Double),
-         msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
-          (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
-            vd._3 + msg.get._3, vd._4)
+         vd: (Array[Double], Array[Double], Double, Double),
+         msg: Option[(Array[Double], Array[Double], Double)]) => {
+          val out1 = vd._1.clone()
+          blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)
+          val out2 = vd._2.clone()
+          blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)
+          (out1, out2, vd._3 + msg.get._3, vd._4)
+        }
       }.cache()
       materialize(gJoinT2)
       g.unpersist()
@@ -160,10 +182,10 @@ object SVDPlusPlus {
 
     // calculate error on training set
     def sendMsgTestF(conf: Conf, u: Double)
-        (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) {
+        (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) {
       val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
       val (p, q) = (usr._1, itm._1)
-      var pred = u + usr._3 + itm._3 + q.dot(usr._2)
+      var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1)
       pred = math.max(pred, conf.minVal)
       pred = math.min(pred, conf.maxVal)
       val err = (ctx.attr - pred) * (ctx.attr - pred)
@@ -173,7 +195,7 @@ object SVDPlusPlus {
     g.cache()
     val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
     val gJoinT3 = g.outerJoinVertices(t3) {
-      (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
+      (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Double]) =>
        if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
     }.cache()
     materialize(gJoinT3)
```
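The recurring rewrite in this file: a jblas chain such as `q.mul(err).subColumnVector(p.mul(conf.gamma7)).mul(conf.gamma2)` becomes a `clone` followed by in-place `dscal` (scaling) and `daxpy` (scaled addition). A standalone sketch of that equivalence, with made-up values:

```scala
import com.github.fommil.netlib.BLAS.{getInstance => blas}

val q = Array(0.1, 0.2)
val p = Array(0.3, 0.4)
val (err, gamma2, gamma7) = (0.5, 0.007, 0.015)

// updateP = (err * q - gamma7 * p) * gamma2, with no jblas temporaries:
val updateP = q.clone()                                   // updateP = q
blas.dscal(q.length, err * gamma2, updateP, 1)            // updateP *= err * gamma2
blas.daxpy(q.length, -gamma7 * gamma2, p, 1, updateP, 1)  // updateP -= gamma7 * gamma2 * p
```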

graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -32,11 +32,11 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
         Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
       }
       val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
-      var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf)
+      val (graph, _) = SVDPlusPlus.run(edges, conf)
       graph.cache()
-      val err = graph.vertices.collect().map{ case (vid, vd) =>
+      val err = graph.vertices.map { case (vid, vd) =>
         if (vid % 2 == 1) vd._4 else 0.0
-      }.reduce(_ + _) / graph.triplets.collect().size
+      }.reduce(_ + _) / graph.numEdges
       assert(err <= svdppErr)
     }
   }
```
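For reference, a minimal usage sketch of the renamed entry point (`run` replaces `runSVDPlusPlus`), assuming an `RDD[Edge[Double]]` named `edges` is already in scope; the argument names in the comment are my reading of the `Conf` constructor, not spelled out in this diff:

```scala
import org.apache.spark.graphx.lib.SVDPlusPlus

// Conf(rank, maxIters, minVal, maxVal, gamma1, gamma2, gamma6, gamma7)
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015)
val (graph, mean) = SVDPlusPlus.run(edges, conf)

// The squared training error of each item vertex sits in the fourth field
// of its attribute; dividing by numEdges keeps the whole job distributed.
val err = graph.vertices.map { case (vid, vd) =>
  if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / graph.numEdges
```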

mllib/pom.xml

Lines changed: 2 additions & 1 deletion

```diff
@@ -59,6 +59,7 @@
       <groupId>org.jblas</groupId>
       <artifactId>jblas</artifactId>
       <version>${jblas.version}</version>
+      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>org.scalanlp</groupId>
@@ -116,7 +117,7 @@
     <dependency>
       <groupId>com.github.fommil.netlib</groupId>
       <artifactId>all</artifactId>
-      <version>1.1.2</version>
+      <version>${netlib.java.version}</version>
       <type>pom</type>
     </dependency>
   </dependencies>
```

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 6 additions & 8 deletions

```diff
@@ -26,7 +26,6 @@ import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
-import org.jblas.DoubleMatrix
 import org.netlib.util.intW
 
 import org.apache.spark.{Logging, Partitioner}
@@ -361,14 +360,14 @@ object ALS extends Logging {
   private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
     private var rank: Int = -1
     private var workspace: NNLS.Workspace = _
-    private var ata: DoubleMatrix = _
+    private var ata: Array[Double] = _
     private var initialized: Boolean = false
 
     private def initialize(rank: Int): Unit = {
       if (!initialized) {
         this.rank = rank
         workspace = NNLS.createWorkspace(rank)
-        ata = new DoubleMatrix(rank, rank)
+        ata = new Array[Double](rank * rank)
         initialized = true
       } else {
         require(this.rank == rank)
@@ -385,7 +384,7 @@ object ALS extends Logging {
       val rank = ne.k
       initialize(rank)
       fillAtA(ne.ata, lambda * ne.n)
-      val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace)
+      val x = NNLS.solve(ata, ne.atb, workspace)
       ne.reset()
       x.map(x => x.toFloat)
     }
@@ -398,17 +397,16 @@ object ALS extends Logging {
       var i = 0
       var pos = 0
       var a = 0.0
-      val data = ata.data
       while (i < rank) {
         var j = 0
         while (j <= i) {
           a = triAtA(pos)
-          data(i * rank + j) = a
-          data(j * rank + i) = a
+          ata(i * rank + j) = a
+          ata(j * rank + i) = a
           pos += 1
           j += 1
         }
-        data(i * rank + i) += lambda
+        ata(i * rank + i) += lambda
         i += 1
       }
     }
```
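To see what `fillAtA` now does to the raw array, here is a self-contained walk-through at `rank = 2` with made-up values (the real code receives the packed triangle from the normal equation and `lambda * ne.n`):

```scala
val rank = 2
val lambda = 0.1
// Packed triangle of the symmetric AtA: entries (0,0), (1,0), (1,1).
val triAtA = Array(4.0, 2.0, 3.0)
val ata = new Array[Double](rank * rank)

var i = 0
var pos = 0
while (i < rank) {
  var j = 0
  while (j <= i) {
    val a = triAtA(pos)
    ata(i * rank + j) = a // lower-triangular entry
    ata(j * rank + i) = a // mirrored upper-triangular entry
    pos += 1
    j += 1
  }
  ata(i * rank + i) += lambda // regularize the diagonal
  i += 1
}
// ata == Array(4.1, 2.0, 2.0, 3.1): the full symmetric matrix that
// NNLS.solve consumes directly, replacing the old DoubleMatrix wrapper.
```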
