@@ -30,14 +30,9 @@ import org.apache.spark.ml.recommendation.ALS._
3030import org .apache .spark .mllib .linalg .Vectors
3131import org .apache .spark .mllib .util .MLlibTestSparkContext
3232import org .apache .spark .mllib .util .TestingUtils ._
33+ import org .apache .spark .rdd .RDD
3334import org .apache .spark .sql .{Row , SQLContext }
3435
35- case class ALSTestData (
36- training : Seq [Rating ],
37- test : Seq [Rating ],
38- userFactors : Map [Int , Array [Float ]],
39- itemFactors : Map [Int , Array [Float ]])
40-
4136class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
4237
4338 private var sqlContext : SQLContext = _
@@ -219,46 +214,61 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
219214 assert(decompressed.toSet === expected)
220215 }
221216
217+ /**
218+ * Generates ratings for testing ALS.
219+ *
220+ * @param numUsers number of users
221+ * @param numItems number of items
222+ * @param rank rank
223+ * @param trainingFraction fraction for training
224+ * @param testFraction fraction for test
225+ * @param noiseLevel noise level for additive Gaussian noise on training data
226+ * @param seed random seed
227+ * @return (training, test)
228+ */
222229 def genALSTestData (
223230 numUsers : Int ,
224231 numItems : Int ,
225232 rank : Int ,
226233 trainingFraction : Double ,
227234 testFraction : Double ,
228235 noiseLevel : Double = 0.0 ,
229- seed : Long = 11L ): ALSTestData = {
236+ seed : Long = 11L ): (RDD [Rating ], RDD [Rating ]) = {
237+ val totalFraction = trainingFraction + testFraction
238+ require(totalFraction <= 1.0 )
230239 val random = new Random (seed)
231240 val userFactors = genFactors(numUsers, rank, random)
232241 val itemFactors = genFactors(numItems, rank, random)
233- val totalFraction = trainingFraction + testFraction
234242 val training = ArrayBuffer .empty[Rating ]
235243 val test = ArrayBuffer .empty[Rating ]
236244 for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
237245 val x = random.nextDouble()
238246 if (x < totalFraction) {
239247 val rating = blas.sdot(rank, userFactor, 1 , itemFactor, 1 )
240248 if (x < trainingFraction) {
241- training += Rating (userId, itemId, rating + noiseLevel.toFloat * random.nextFloat())
249+ val noise = noiseLevel * random.nextGaussian()
250+ training += Rating (userId, itemId, rating + noise.toFloat)
242251 } else {
243252 test += Rating (userId, itemId, rating)
244253 }
245254 }
246255 }
247256 logInfo(s " Generated ${training.size} ratings for training and ${test.size} for test. " )
248- ALSTestData (training.toSeq, test.toSeq, userFactors, itemFactors )
257+ (sc.parallelize(training, 2 ), sc.parallelize(test, 2 ) )
249258 }
250259
251- def genFactors (size : Int , rank : Int , random : Random ): Map [ Int , Array [Float ]] = {
260+ private def genFactors (size : Int , rank : Int , random : Random ): Seq [( Int , Array [Float ]) ] = {
252261 require(size > 0 && size < Int .MaxValue / 3 )
253262 val ids = mutable.Set .empty[Int ]
254263 while (ids.size < size) {
255264 ids += random.nextInt()
256265 }
257- ids.map(id => (id, Array .fill(rank)(random.nextFloat()))).toMap
266+ ids.toSeq.sorted. map(id => (id, Array .fill(rank)(random.nextFloat())))
258267 }
259268
260269 def testALS (
261- alsTestData : ALSTestData ,
270+ training : RDD [Rating ],
271+ test : RDD [Rating ],
262272 rank : Int ,
263273 maxIter : Int ,
264274 regParam : Double ,
@@ -267,8 +277,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
267277 numItemBlocks : Int = 3 ): Unit = {
268278 val sqlContext = this .sqlContext
269279 import sqlContext .{createSchemaRDD , symbolToUnresolvedAttribute }
270- val training = sc.parallelize(alsTestData.training, 2 )
271- val test = sc.parallelize(alsTestData.test, 2 )
272280 val als = new ALS ()
273281 .setRank(rank)
274282 .setRegParam(regParam)
@@ -287,48 +295,39 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
287295 }
288296
289297 test(" exact rank-1 matrix" ) {
290- val testData = genALSTestData(
291- numUsers = 20 ,
292- numItems = 40 ,
293- rank = 1 ,
294- trainingFraction = 0.6 ,
295- testFraction = 0.3 )
296- testALS(testData, maxIter = 1 , rank = 1 , regParam = 1e-4 , targetRMSE = 0.002 )
297- testALS(testData, maxIter = 1 , rank = 2 , regParam = 1e-4 , targetRMSE = 0.002 )
298+ val (training, test) = genALSTestData(numUsers = 20 , numItems = 40 , rank = 1 ,
299+ trainingFraction = 0.6 , testFraction = 0.3 )
300+ testALS(training, test, maxIter = 1 , rank = 1 , regParam = 1e-5 , targetRMSE = 0.001 )
301+ testALS(training, test, maxIter = 1 , rank = 2 , regParam = 1e-5 , targetRMSE = 0.001 )
298302 }
299303
300304 test(" approximate rank-1 matrix" ) {
301- val testData = genALSTestData(
302- numUsers = 20 ,
303- numItems = 40 ,
304- rank = 1 ,
305- trainingFraction = 0.6 ,
306- testFraction = 0.3 ,
307- noiseLevel = 0.01 )
308- testALS(testData, maxIter = 2 , rank = 1 , regParam = 0.01 , targetRMSE = 0.02 )
309- testALS(testData, maxIter = 2 , rank = 2 , regParam = 0.01 , targetRMSE = 0.02 )
305+ val (training, test) = genALSTestData(numUsers = 20 , numItems = 40 , rank = 1 ,
306+ trainingFraction = 0.6 , testFraction = 0.3 , noiseLevel = 0.01 )
307+ testALS(training, test, maxIter = 2 , rank = 1 , regParam = 0.01 , targetRMSE = 0.02 )
308+ testALS(training, test, maxIter = 2 , rank = 2 , regParam = 0.01 , targetRMSE = 0.02 )
309+ }
310+
311+ test(" approximate rank-2 matrix" ) {
312+ val (training, test) = genALSTestData(numUsers = 20 , numItems = 40 , rank = 2 ,
313+ trainingFraction = 0.6 , testFraction = 0.3 , noiseLevel = 0.01 )
314+ testALS(training, test, maxIter = 4 , rank = 2 , regParam = 0.01 , targetRMSE = 0.03 )
315+ testALS(training, test, maxIter = 4 , rank = 3 , regParam = 0.01 , targetRMSE = 0.03 )
310316 }
311317
312- test(" exact rank-2 matrix" ) {
313- val testData = genALSTestData(
314- numUsers = 20 ,
315- numItems = 40 ,
316- rank = 2 ,
317- trainingFraction = 0.6 ,
318- testFraction = 0.3 )
319- testALS(testData, maxIter = 4 , rank = 2 , regParam = 1e-4 , targetRMSE = 0.002 )
320- testALS(testData, maxIter = 6 , rank = 3 , regParam = 0.01 , targetRMSE = 0.04 )
318+ test(" different block settings" ) {
319+ val (training, test) = genALSTestData(numUsers = 20 , numItems = 40 , rank = 2 ,
320+ trainingFraction = 0.6 , testFraction = 0.3 , noiseLevel = 0.01 )
321+ for ((numUserBlocks, numItemBlocks) <- Seq ((1 , 1 ), (1 , 2 ), (2 , 1 ), (2 , 2 ))) {
322+ testALS(training, test, maxIter = 4 , rank = 2 , regParam = 0.01 , targetRMSE = 0.03 ,
323+ numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
324+ }
321325 }
322326
323- test(" approximate rank-2 matrix" ) {
324- val testData = genALSTestData(
325- numUsers = 20 ,
326- numItems = 40 ,
327- rank = 2 ,
328- trainingFraction = 0.6 ,
329- testFraction = 0.3 ,
330- noiseLevel = 0.01 )
331- testALS(testData, maxIter = 4 , rank = 2 , regParam = 0.01 , targetRMSE = 0.03 )
332- testALS(testData, maxIter = 4 , rank = 3 , regParam = 0.01 , targetRMSE = 0.03 )
327+ test(" more blocks than ratings" ) {
328+ val (training, test) = genALSTestData(numUsers = 4 , numItems = 4 , rank = 1 ,
329+ trainingFraction = 0.7 , testFraction = 0.3 )
330+ testALS(training, test, maxIter = 2 , rank = 1 , regParam = 1e-4 , targetRMSE = 0.002 ,
331+ numItemBlocks = 5 , numUserBlocks = 5 )
333332 }
334333}
0 commit comments