|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.recommendation |
19 | 19 |
|
| 20 | +import org.apache.spark.mllib.recommendation.ALS.BlockStats |
| 21 | +import org.apache.spark.mllib.util.MLlibTestSparkContext |
| 22 | +import org.apache.spark.storage.StorageLevel |
| 23 | +import org.jblas.DoubleMatrix |
| 24 | +import org.scalatest.FunSuite |
| 25 | + |
20 | 26 | import scala.collection.JavaConversions._ |
21 | 27 | import scala.math.abs |
22 | 28 | import scala.util.Random |
23 | 29 |
|
24 | | -import org.scalatest.FunSuite |
25 | | -import org.jblas.DoubleMatrix |
26 | | - |
27 | | -import org.apache.spark.SparkContext._ |
28 | | -import org.apache.spark.mllib.util.MLlibTestSparkContext |
29 | | -import org.apache.spark.mllib.recommendation.ALS.BlockStats |
30 | | - |
31 | 30 | object ALSSuite { |
32 | 31 |
|
33 | 32 | def generateRatingsAsJavaList( |
@@ -139,6 +138,32 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { |
139 | 138 | assert(u11 != u2) |
140 | 139 | } |
141 | 140 |
|
| 141 | + test("Storage Level for RDDs in model") { |
| 142 | + val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2) |
| 143 | + var storageLevel = StorageLevel.MEMORY_ONLY |
| 144 | + var model = new ALS() |
| 145 | + .setRank(5) |
| 146 | + .setIterations(1) |
| 147 | + .setLambda(1.0) |
| 148 | + .setBlocks(2) |
| 149 | + .setSeed(1) |
| 150 | + .setFinalRDDStorageLevel(storageLevel) |
| 151 | + .run(ratings) |
| 152 | + assert(model.productFeatures.getStorageLevel == storageLevel); |
| 153 | + assert(model.userFeatures.getStorageLevel == storageLevel); |
| 154 | + storageLevel = StorageLevel.DISK_ONLY |
| 155 | + model = new ALS() |
| 156 | + .setRank(5) |
| 157 | + .setIterations(1) |
| 158 | + .setLambda(1.0) |
| 159 | + .setBlocks(2) |
| 160 | + .setSeed(1) |
| 161 | + .setFinalRDDStorageLevel(storageLevel) |
| 162 | + .run(ratings) |
| 163 | + assert(model.productFeatures.getStorageLevel == storageLevel); |
| 164 | + assert(model.userFeatures.getStorageLevel == storageLevel); |
| 165 | + } |
| 166 | + |
142 | 167 | test("negative ids") { |
143 | 168 | val data = ALSSuite.generateRatings(50, 50, 2, 0.7, false, false) |
144 | 169 | val ratings = sc.parallelize(data._1.map { case Rating(u, p, r) => |
|
0 commit comments