@@ -19,16 +19,33 @@ package org.apache.spark.ml.recommendation
1919
2020import java .util .Random
2121
22+ import scala .collection .mutable
23+ import scala .collection .mutable .ArrayBuffer
24+
25+ import com .github .fommil .netlib .BLAS .{getInstance => blas }
2226import org .scalatest .FunSuite
2327
28+ import org .apache .spark .Logging
2429import org .apache .spark .ml .recommendation .ALS ._
2530import org .apache .spark .mllib .linalg .Vectors
2631import org .apache .spark .mllib .util .MLlibTestSparkContext
2732import org .apache .spark .mllib .util .TestingUtils ._
33+ import org .apache .spark .sql .{Row , SQLContext }
2834
29- import scala .collection .mutable .ArrayBuffer
35+ case class ALSTestData (
36+ training : Seq [Rating ],
37+ test : Seq [Rating ],
38+ userFactors : Map [Int , Array [Float ]],
39+ itemFactors : Map [Int , Array [Float ]])
3040
31- class ALSSuite extends FunSuite with MLlibTestSparkContext {
41+ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
42+
43+ private var sqlContext : SQLContext = _
44+
45+ override def beforeAll (): Unit = {
46+ super .beforeAll()
47+ sqlContext = new SQLContext (sc)
48+ }
3249
3350 test(" LocalIndexEncoder" ) {
3451 val random = new Random
@@ -201,4 +218,117 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
201218 }
202219 assert(decompressed.toSet === expected)
203220 }
221+
222+ def genALSTestData (
223+ numUsers : Int ,
224+ numItems : Int ,
225+ rank : Int ,
226+ trainingFraction : Double ,
227+ testFraction : Double ,
228+ noiseLevel : Double = 0.0 ,
229+ seed : Long = 11L ): ALSTestData = {
230+ val random = new Random (seed)
231+ val userFactors = genFactors(numUsers, rank, random)
232+ val itemFactors = genFactors(numItems, rank, random)
233+ val totalFraction = trainingFraction + testFraction
234+ val training = ArrayBuffer .empty[Rating ]
235+ val test = ArrayBuffer .empty[Rating ]
236+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
237+ val x = random.nextDouble()
238+ if (x < totalFraction) {
239+ val rating = blas.sdot(rank, userFactor, 1 , itemFactor, 1 )
240+ if (x < trainingFraction) {
241+ training += Rating (userId, itemId, rating + noiseLevel.toFloat * random.nextFloat())
242+ } else {
243+ test += Rating (userId, itemId, rating)
244+ }
245+ }
246+ }
247+ logInfo(s " Generated ${training.size} ratings for training and ${test.size} for test. " )
248+ ALSTestData (training.toSeq, test.toSeq, userFactors, itemFactors)
249+ }
250+
251+ def genFactors (size : Int , rank : Int , random : Random ): Map [Int , Array [Float ]] = {
252+ require(size > 0 && size < Int .MaxValue / 3 )
253+ val ids = mutable.Set .empty[Int ]
254+ while (ids.size < size) {
255+ ids += random.nextInt()
256+ }
257+ ids.map(id => (id, Array .fill(rank)(random.nextFloat()))).toMap
258+ }
259+
260+ def testALS (
261+ alsTestData : ALSTestData ,
262+ rank : Int ,
263+ maxIter : Int ,
264+ regParam : Double ,
265+ targetRMSE : Double ,
266+ numUserBlocks : Int = 2 ,
267+ numItemBlocks : Int = 3 ): Unit = {
268+ val sqlContext = this .sqlContext
269+ import sqlContext .{createSchemaRDD , symbolToUnresolvedAttribute }
270+ val training = sc.parallelize(alsTestData.training, 2 )
271+ val test = sc.parallelize(alsTestData.test, 2 )
272+ val als = new ALS ()
273+ .setRank(rank)
274+ .setRegParam(regParam)
275+ .setNumUserBlocks(numUserBlocks)
276+ .setNumItemBlocks(numItemBlocks)
277+ val model = als.fit(training)
278+ val prediction = model.transform(test)
279+ val mse = prediction.select(' rating , ' prediction )
280+ .map { case Row (rating : Float , prediction : Float ) =>
281+ val err = rating.toDouble - prediction
282+ err * err
283+ }.mean()
284+ val rmse = math.sqrt(mse)
285+ logInfo(s " Test RMSE is $rmse. " )
286+ assert(rmse < targetRMSE)
287+ }
288+
289+ test(" exact rank-1 matrix" ) {
290+ val testData = genALSTestData(
291+ numUsers = 20 ,
292+ numItems = 40 ,
293+ rank = 1 ,
294+ trainingFraction = 0.6 ,
295+ testFraction = 0.3 )
296+ testALS(testData, maxIter = 1 , rank = 1 , regParam = 1e-4 , targetRMSE = 0.002 )
297+ testALS(testData, maxIter = 1 , rank = 2 , regParam = 1e-4 , targetRMSE = 0.002 )
298+ }
299+
300+ test(" approximate rank-1 matrix" ) {
301+ val testData = genALSTestData(
302+ numUsers = 20 ,
303+ numItems = 40 ,
304+ rank = 1 ,
305+ trainingFraction = 0.6 ,
306+ testFraction = 0.3 ,
307+ noiseLevel = 0.01 )
308+ testALS(testData, maxIter = 2 , rank = 1 , regParam = 0.01 , targetRMSE = 0.02 )
309+ testALS(testData, maxIter = 2 , rank = 2 , regParam = 0.01 , targetRMSE = 0.02 )
310+ }
311+
312+ test(" exact rank-2 matrix" ) {
313+ val testData = genALSTestData(
314+ numUsers = 20 ,
315+ numItems = 40 ,
316+ rank = 2 ,
317+ trainingFraction = 0.6 ,
318+ testFraction = 0.3 )
319+ testALS(testData, maxIter = 4 , rank = 2 , regParam = 1e-4 , targetRMSE = 0.002 )
320+ testALS(testData, maxIter = 6 , rank = 3 , regParam = 0.01 , targetRMSE = 0.04 )
321+ }
322+
323+ test(" approximate rank-2 matrix" ) {
324+ val testData = genALSTestData(
325+ numUsers = 20 ,
326+ numItems = 40 ,
327+ rank = 2 ,
328+ trainingFraction = 0.6 ,
329+ testFraction = 0.3 ,
330+ noiseLevel = 0.01 )
331+ testALS(testData, maxIter = 4 , rank = 2 , regParam = 0.01 , targetRMSE = 0.03 )
332+ testALS(testData, maxIter = 4 , rank = 3 , regParam = 0.01 , targetRMSE = 0.03 )
333+ }
204334}
0 commit comments