Skip to content

Commit 72df5a3

Browse files
zeitos authored and mengxr committed
SPARK-5148 [MLlib] Make usersOut/productsOut storagelevel in ALS configurable
Author: Fernando Otero (ZeoS) <[email protected]>

Closes #3953 from zeitos/storageLevel and squashes the following commits:

0f070b9 [Fernando Otero (ZeoS)] fix imports
6869e80 [Fernando Otero (ZeoS)] fix comment length
90c9f7e [Fernando Otero (ZeoS)] fix comment length
18a992e [Fernando Otero (ZeoS)] changing storage level
1 parent 538f221 commit 72df5a3

File tree

2 files changed

+43
-2
lines changed
  • mllib/src
    • main/scala/org/apache/spark/mllib/recommendation
    • test/scala/org/apache/spark/mllib/recommendation

2 files changed

+43
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,7 @@ class ALS private (
116116

117117
/** storage level for user/product in/out links */
118118
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
119+
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
119120

120121
/**
121122
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
@@ -204,6 +205,19 @@ class ALS private (
204205
this
205206
}
206207

208+
/**
209+
* :: DeveloperApi ::
210+
* Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default
211+
* value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
212+
* `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement,
213+
* at the cost of speed.
214+
*/
215+
@DeveloperApi
216+
def setFinalRDDStorageLevel(storageLevel: StorageLevel): this.type = {
217+
this.finalRDDStorageLevel = storageLevel
218+
this
219+
}
220+
207221
/**
208222
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
209223
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -307,8 +321,8 @@ class ALS private (
307321
val usersOut = unblockFactors(users, userOutLinks)
308322
val productsOut = unblockFactors(products, productOutLinks)
309323

310-
usersOut.setName("usersOut").persist(StorageLevel.MEMORY_AND_DISK)
311-
productsOut.setName("productsOut").persist(StorageLevel.MEMORY_AND_DISK)
324+
usersOut.setName("usersOut").persist(finalRDDStorageLevel)
325+
productsOut.setName("productsOut").persist(finalRDDStorageLevel)
312326

313327
// Materialize usersOut and productsOut.
314328
usersOut.count()

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ import org.jblas.DoubleMatrix
2727
import org.apache.spark.SparkContext._
2828
import org.apache.spark.mllib.util.MLlibTestSparkContext
2929
import org.apache.spark.mllib.recommendation.ALS.BlockStats
30+
import org.apache.spark.storage.StorageLevel
3031

3132
object ALSSuite {
3233

@@ -139,6 +140,32 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
139140
assert(u11 != u2)
140141
}
141142

143+
test("Storage Level for RDDs in model") {
144+
val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
145+
var storageLevel = StorageLevel.MEMORY_ONLY
146+
var model = new ALS()
147+
.setRank(5)
148+
.setIterations(1)
149+
.setLambda(1.0)
150+
.setBlocks(2)
151+
.setSeed(1)
152+
.setFinalRDDStorageLevel(storageLevel)
153+
.run(ratings)
154+
assert(model.productFeatures.getStorageLevel == storageLevel);
155+
assert(model.userFeatures.getStorageLevel == storageLevel);
156+
storageLevel = StorageLevel.DISK_ONLY
157+
model = new ALS()
158+
.setRank(5)
159+
.setIterations(1)
160+
.setLambda(1.0)
161+
.setBlocks(2)
162+
.setSeed(1)
163+
.setFinalRDDStorageLevel(storageLevel)
164+
.run(ratings)
165+
assert(model.productFeatures.getStorageLevel == storageLevel);
166+
assert(model.userFeatures.getStorageLevel == storageLevel);
167+
}
168+
142169
test("negative ids") {
143170
val data = ALSSuite.generateRatings(50, 50, 2, 0.7, false, false)
144171
val ratings = sc.parallelize(data._1.map { case Rating(u, p, r) =>

0 commit comments

Comments
 (0)