Skip to content

Commit 6869e80

Browse files
committed
fix comment length
1 parent 90c9f7e commit 6869e80

File tree

1 file changed

+32
-7
lines changed
  • mllib/src/test/scala/org/apache/spark/mllib/recommendation

1 file changed

+32
-7
lines changed

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@
1717

1818
package org.apache.spark.mllib.recommendation
1919

20+
import org.apache.spark.mllib.recommendation.ALS.BlockStats
21+
import org.apache.spark.mllib.util.MLlibTestSparkContext
22+
import org.apache.spark.storage.StorageLevel
23+
import org.jblas.DoubleMatrix
24+
import org.scalatest.FunSuite
25+
2026
import scala.collection.JavaConversions._
2127
import scala.math.abs
2228
import scala.util.Random
2329

24-
import org.scalatest.FunSuite
25-
import org.jblas.DoubleMatrix
26-
27-
import org.apache.spark.SparkContext._
28-
import org.apache.spark.mllib.util.MLlibTestSparkContext
29-
import org.apache.spark.mllib.recommendation.ALS.BlockStats
30-
3130
object ALSSuite {
3231

3332
def generateRatingsAsJavaList(
@@ -139,6 +138,32 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
139138
assert(u11 != u2)
140139
}
141140

141+
test("Storage Level for RDDs in model") {
142+
val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
143+
var storageLevel = StorageLevel.MEMORY_ONLY
144+
var model = new ALS()
145+
.setRank(5)
146+
.setIterations(1)
147+
.setLambda(1.0)
148+
.setBlocks(2)
149+
.setSeed(1)
150+
.setFinalRDDStorageLevel(storageLevel)
151+
.run(ratings)
152+
assert(model.productFeatures.getStorageLevel == storageLevel);
153+
assert(model.userFeatures.getStorageLevel == storageLevel);
154+
storageLevel = StorageLevel.DISK_ONLY
155+
model = new ALS()
156+
.setRank(5)
157+
.setIterations(1)
158+
.setLambda(1.0)
159+
.setBlocks(2)
160+
.setSeed(1)
161+
.setFinalRDDStorageLevel(storageLevel)
162+
.run(ratings)
163+
assert(model.productFeatures.getStorageLevel == storageLevel);
164+
assert(model.userFeatures.getStorageLevel == storageLevel);
165+
}
166+
142167
test("negative ids") {
143168
val data = ALSSuite.generateRatings(50, 50, 2, 0.7, false, false)
144169
val ratings = sc.parallelize(data._1.map { case Rating(u, p, r) =>

0 commit comments

Comments
 (0)