Skip to content

Commit 2a8deb3

Browse files
committed
add some ALS tests
1 parent 857e876 commit 2a8deb3

File tree

2 files changed

+133
-3
lines changed

2 files changed

+133
-3
lines changed

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,33 @@ package org.apache.spark.ml.recommendation
1919

2020
import java.util.Random
2121

22+
import scala.collection.mutable
23+
import scala.collection.mutable.ArrayBuffer
24+
25+
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2226
import org.scalatest.FunSuite
2327

28+
import org.apache.spark.Logging
2429
import org.apache.spark.ml.recommendation.ALS._
2530
import org.apache.spark.mllib.linalg.Vectors
2631
import org.apache.spark.mllib.util.MLlibTestSparkContext
2732
import org.apache.spark.mllib.util.TestingUtils._
33+
import org.apache.spark.sql.{Row, SQLContext}
2834

29-
import scala.collection.mutable.ArrayBuffer
35+
case class ALSTestData(
36+
training: Seq[Rating],
37+
test: Seq[Rating],
38+
userFactors: Map[Int, Array[Float]],
39+
itemFactors: Map[Int, Array[Float]])
3040

31-
class ALSSuite extends FunSuite with MLlibTestSparkContext {
41+
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
42+
43+
private var sqlContext: SQLContext = _
44+
45+
override def beforeAll(): Unit = {
46+
super.beforeAll()
47+
sqlContext = new SQLContext(sc)
48+
}
3249

3350
test("LocalIndexEncoder") {
3451
val random = new Random
@@ -201,4 +218,117 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
201218
}
202219
assert(decompressed.toSet === expected)
203220
}
221+
222+
def genALSTestData(
223+
numUsers: Int,
224+
numItems: Int,
225+
rank: Int,
226+
trainingFraction: Double,
227+
testFraction: Double,
228+
noiseLevel: Double = 0.0,
229+
seed: Long = 11L): ALSTestData = {
230+
val random = new Random(seed)
231+
val userFactors = genFactors(numUsers, rank, random)
232+
val itemFactors = genFactors(numItems, rank, random)
233+
val totalFraction = trainingFraction + testFraction
234+
val training = ArrayBuffer.empty[Rating]
235+
val test = ArrayBuffer.empty[Rating]
236+
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
237+
val x = random.nextDouble()
238+
if (x < totalFraction) {
239+
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
240+
if (x < trainingFraction) {
241+
training += Rating(userId, itemId, rating + noiseLevel.toFloat * random.nextFloat())
242+
} else {
243+
test += Rating(userId, itemId, rating)
244+
}
245+
}
246+
}
247+
logInfo(s"Generated ${training.size} ratings for training and ${test.size} for test.")
248+
ALSTestData(training.toSeq, test.toSeq, userFactors, itemFactors)
249+
}
250+
251+
def genFactors(size: Int, rank: Int, random: Random): Map[Int, Array[Float]] = {
252+
require(size > 0 && size < Int.MaxValue / 3)
253+
val ids = mutable.Set.empty[Int]
254+
while (ids.size < size) {
255+
ids += random.nextInt()
256+
}
257+
ids.map(id => (id, Array.fill(rank)(random.nextFloat()))).toMap
258+
}
259+
260+
def testALS(
261+
alsTestData: ALSTestData,
262+
rank: Int,
263+
maxIter: Int,
264+
regParam: Double,
265+
targetRMSE: Double,
266+
numUserBlocks: Int = 2,
267+
numItemBlocks: Int = 3): Unit = {
268+
val sqlContext = this.sqlContext
269+
import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
270+
val training = sc.parallelize(alsTestData.training, 2)
271+
val test = sc.parallelize(alsTestData.test, 2)
272+
val als = new ALS()
273+
.setRank(rank)
274+
.setRegParam(regParam)
275+
.setNumUserBlocks(numUserBlocks)
276+
.setNumItemBlocks(numItemBlocks)
277+
val model = als.fit(training)
278+
val prediction = model.transform(test)
279+
val mse = prediction.select('rating, 'prediction)
280+
.map { case Row(rating: Float, prediction: Float) =>
281+
val err = rating.toDouble - prediction
282+
err * err
283+
}.mean()
284+
val rmse = math.sqrt(mse)
285+
logInfo(s"Test RMSE is $rmse.")
286+
assert(rmse < targetRMSE)
287+
}
288+
289+
test("exact rank-1 matrix") {
290+
val testData = genALSTestData(
291+
numUsers = 20,
292+
numItems = 40,
293+
rank = 1,
294+
trainingFraction = 0.6,
295+
testFraction = 0.3)
296+
testALS(testData, maxIter = 1, rank = 1, regParam = 1e-4, targetRMSE = 0.002)
297+
testALS(testData, maxIter = 1, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
298+
}
299+
300+
test("approximate rank-1 matrix") {
301+
val testData = genALSTestData(
302+
numUsers = 20,
303+
numItems = 40,
304+
rank = 1,
305+
trainingFraction = 0.6,
306+
testFraction = 0.3,
307+
noiseLevel = 0.01)
308+
testALS(testData, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
309+
testALS(testData, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
310+
}
311+
312+
test("exact rank-2 matrix") {
313+
val testData = genALSTestData(
314+
numUsers = 20,
315+
numItems = 40,
316+
rank = 2,
317+
trainingFraction = 0.6,
318+
testFraction = 0.3)
319+
testALS(testData, maxIter = 4, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
320+
testALS(testData, maxIter = 6, rank = 3, regParam = 0.01, targetRMSE = 0.04)
321+
}
322+
323+
test("approximate rank-2 matrix") {
324+
val testData = genALSTestData(
325+
numUsers = 20,
326+
numItems = 40,
327+
rank = 2,
328+
trainingFraction = 0.6,
329+
testFraction = 0.3,
330+
noiseLevel = 0.01)
331+
testALS(testData, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
332+
testALS(testData, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
333+
}
204334
}

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
188188
* @param samplingRate what fraction of the user-product pairs are known
189189
* @param matchThreshold max difference allowed to consider a predicted rating correct
190190
* @param implicitPrefs flag to test implicit feedback
191-
* @param bulkPredict flag to test bulk prediciton
191+
* @param bulkPredict flag to test bulk predicition
192192
* @param negativeWeights whether the generated data can contain negative values
193193
* @param numUserBlocks number of user blocks to partition users into
194194
* @param numProductBlocks number of product blocks to partition products into

0 commit comments

Comments
 (0)