Skip to content

Commit a76da7b

Browse files
committed
update ALS tests
1 parent 2a8deb3 commit a76da7b

File tree

2 files changed

+59
-56
lines changed
  • mllib/src

2 files changed

+59
-56
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ private[recommendation] object ALS extends Logging {
643643
*/
644644
def compress(): InBlock = {
645645
val sz = size
646-
assert(sz > 0) // TODO: Check whether it is possible to have empty blocks.
646+
assert(sz > 0, "Empty in-link block should not exist.")
647647
sort()
648648
val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
649649
val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
@@ -681,7 +681,7 @@ private[recommendation] object ALS extends Logging {
681681

682682
private def sort(): Unit = {
683683
val sz = size
684-
// Since there might be interleaved log messages, we insert a unqiue id for easy pairing.
684+
// Since there might be interleaved log messages, we insert a unique id for easy pairing.
685685
val sortId = Utils.random.nextInt()
686686
logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
687687
val start = System.nanoTime()
@@ -807,7 +807,7 @@ private[recommendation] object ALS extends Logging {
807807
i += 1
808808
}
809809
logDebug(
810-
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + "seconds.")
810+
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
811811
val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
812812
(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
813813
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
@@ -845,7 +845,7 @@ private[recommendation] object ALS extends Logging {
845845
}
846846

847847
/**
848-
* Compute dst factors by forming and solving least square problems.
848+
* Compute dst factors by constructing and solving least squares problems.
849849
*
850850
* @param srcFactorBlocks src factors
851851
* @param srcOutBlocks src out-blocks
@@ -867,6 +867,7 @@ private[recommendation] object ALS extends Logging {
867867
srcEncoder: LocalIndexEncoder,
868868
implicitPrefs: Boolean = false,
869869
alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
870+
val numSrcBlocks = srcFactorBlocks.partitions.size
870871
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
871872
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
872873
case (srcBlockId, (srcOutBlock, srcFactors)) =>
@@ -877,7 +878,10 @@ private[recommendation] object ALS extends Logging {
877878
val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
878879
dstInBlocks.join(merged).mapValues {
879880
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
880-
val sortedSrcFactors = srcFactors.toSeq.sortBy(_._1).map(_._2).toArray
881+
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
882+
srcFactors.foreach { case (srcBlockId, factors) =>
883+
sortedSrcFactors(srcBlockId) = factors
884+
}
881885
val dstFactors = new Array[Array[Float]](dstIds.size)
882886
var j = 0
883887
val ls = new NormalEquation(rank)

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 50 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,9 @@ import org.apache.spark.ml.recommendation.ALS._
3030
import org.apache.spark.mllib.linalg.Vectors
3131
import org.apache.spark.mllib.util.MLlibTestSparkContext
3232
import org.apache.spark.mllib.util.TestingUtils._
33+
import org.apache.spark.rdd.RDD
3334
import org.apache.spark.sql.{Row, SQLContext}
3435

35-
case class ALSTestData(
36-
training: Seq[Rating],
37-
test: Seq[Rating],
38-
userFactors: Map[Int, Array[Float]],
39-
itemFactors: Map[Int, Array[Float]])
40-
4136
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
4237

4338
private var sqlContext: SQLContext = _
@@ -219,46 +214,61 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
219214
assert(decompressed.toSet === expected)
220215
}
221216

217+
/**
218+
* Generates ratings for testing ALS.
219+
*
220+
* @param numUsers number of users
221+
* @param numItems number of items
222+
* @param rank rank of the generated factors (length of each user and item factor vector)
223+
* @param trainingFraction fraction of (user, item) pairs assigned to the training set
224+
* @param testFraction fraction of (user, item) pairs assigned to the test set
225+
* @param noiseLevel noise level for additive Gaussian noise on training data
226+
* @param seed random seed
227+
* @return (training, test)
228+
*/
222229
def genALSTestData(
223230
numUsers: Int,
224231
numItems: Int,
225232
rank: Int,
226233
trainingFraction: Double,
227234
testFraction: Double,
228235
noiseLevel: Double = 0.0,
229-
seed: Long = 11L): ALSTestData = {
236+
seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
237+
val totalFraction = trainingFraction + testFraction
238+
require(totalFraction <= 1.0)
230239
val random = new Random(seed)
231240
val userFactors = genFactors(numUsers, rank, random)
232241
val itemFactors = genFactors(numItems, rank, random)
233-
val totalFraction = trainingFraction + testFraction
234242
val training = ArrayBuffer.empty[Rating]
235243
val test = ArrayBuffer.empty[Rating]
236244
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
237245
val x = random.nextDouble()
238246
if (x < totalFraction) {
239247
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
240248
if (x < trainingFraction) {
241-
training += Rating(userId, itemId, rating + noiseLevel.toFloat * random.nextFloat())
249+
val noise = noiseLevel * random.nextGaussian()
250+
training += Rating(userId, itemId, rating + noise.toFloat)
242251
} else {
243252
test += Rating(userId, itemId, rating)
244253
}
245254
}
246255
}
247256
logInfo(s"Generated ${training.size} ratings for training and ${test.size} for test.")
248-
ALSTestData(training.toSeq, test.toSeq, userFactors, itemFactors)
257+
(sc.parallelize(training, 2), sc.parallelize(test, 2))
249258
}
250259

251-
def genFactors(size: Int, rank: Int, random: Random): Map[Int, Array[Float]] = {
260+
private def genFactors(size: Int, rank: Int, random: Random): Seq[(Int, Array[Float])] = {
252261
require(size > 0 && size < Int.MaxValue / 3)
253262
val ids = mutable.Set.empty[Int]
254263
while (ids.size < size) {
255264
ids += random.nextInt()
256265
}
257-
ids.map(id => (id, Array.fill(rank)(random.nextFloat()))).toMap
266+
ids.toSeq.sorted.map(id => (id, Array.fill(rank)(random.nextFloat())))
258267
}
259268

260269
def testALS(
261-
alsTestData: ALSTestData,
270+
training: RDD[Rating],
271+
test: RDD[Rating],
262272
rank: Int,
263273
maxIter: Int,
264274
regParam: Double,
@@ -267,8 +277,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
267277
numItemBlocks: Int = 3): Unit = {
268278
val sqlContext = this.sqlContext
269279
import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
270-
val training = sc.parallelize(alsTestData.training, 2)
271-
val test = sc.parallelize(alsTestData.test, 2)
272280
val als = new ALS()
273281
.setRank(rank)
274282
.setRegParam(regParam)
@@ -287,48 +295,39 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
287295
}
288296

289297
test("exact rank-1 matrix") {
290-
val testData = genALSTestData(
291-
numUsers = 20,
292-
numItems = 40,
293-
rank = 1,
294-
trainingFraction = 0.6,
295-
testFraction = 0.3)
296-
testALS(testData, maxIter = 1, rank = 1, regParam = 1e-4, targetRMSE = 0.002)
297-
testALS(testData, maxIter = 1, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
298+
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
299+
trainingFraction = 0.6, testFraction = 0.3)
300+
testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
301+
testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
298302
}
299303

300304
test("approximate rank-1 matrix") {
301-
val testData = genALSTestData(
302-
numUsers = 20,
303-
numItems = 40,
304-
rank = 1,
305-
trainingFraction = 0.6,
306-
testFraction = 0.3,
307-
noiseLevel = 0.01)
308-
testALS(testData, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
309-
testALS(testData, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
305+
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
306+
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
307+
testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
308+
testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
309+
}
310+
311+
test("approximate rank-2 matrix") {
312+
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
313+
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
314+
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
315+
testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
310316
}
311317

312-
test("exact rank-2 matrix") {
313-
val testData = genALSTestData(
314-
numUsers = 20,
315-
numItems = 40,
316-
rank = 2,
317-
trainingFraction = 0.6,
318-
testFraction = 0.3)
319-
testALS(testData, maxIter = 4, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
320-
testALS(testData, maxIter = 6, rank = 3, regParam = 0.01, targetRMSE = 0.04)
318+
test("different block settings") {
319+
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
320+
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
321+
for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
322+
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
323+
numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
324+
}
321325
}
322326

323-
test("approximate rank-2 matrix") {
324-
val testData = genALSTestData(
325-
numUsers = 20,
326-
numItems = 40,
327-
rank = 2,
328-
trainingFraction = 0.6,
329-
testFraction = 0.3,
330-
noiseLevel = 0.01)
331-
testALS(testData, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
332-
testALS(testData, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
327+
test("more blocks than ratings") {
328+
val (training, test) = genALSTestData(numUsers = 4, numItems = 4, rank = 1,
329+
trainingFraction = 0.7, testFraction = 0.3)
330+
testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
331+
numItemBlocks = 5, numUserBlocks = 5)
333332
}
334333
}

0 commit comments

Comments
 (0)