Skip to content

Commit 8e5db6a

Browse files
committed
Separating cleaning methods in standalone methods + Tests for cleaning correctness
1 parent 7af4945 commit 8e5db6a

File tree

2 files changed

+122
-35
lines changed

2 files changed

+122
-35
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -144,45 +144,13 @@ class PrefixSpan private (
144144
logInfo(s"minimum count for a frequent pattern: $minCount")
145145

146146
// Find frequent items.
147-
val freqItemAndCounts = data.flatMap { itemsets =>
148-
val uniqItems = mutable.Set.empty[Item]
149-
itemsets.foreach { _.foreach { item =>
150-
uniqItems += item
151-
}}
152-
uniqItems.toIterator.map((_, 1L))
153-
}.reduceByKey(_ + _)
154-
.filter { case (_, count) =>
155-
count >= minCount
156-
}.collect()
157-
val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1)
147+
val freqItems = findFrequentItems(data, minCount)
158148
logInfo(s"number of frequent items: ${freqItems.length}")
159149

160150
// Keep only frequent items from input sequences and convert them to internal storage.
161151
val itemToInt = freqItems.zipWithIndex.toMap
162-
val dataInternalRepr = data.flatMap { itemsets =>
163-
val allItems = mutable.ArrayBuilder.make[Int]
164-
var containsFreqItems = false
165-
allItems += 0
166-
itemsets.foreach { itemsets =>
167-
val items = mutable.ArrayBuilder.make[Int]
168-
itemsets.foreach { item =>
169-
if (itemToInt.contains(item)) {
170-
items += itemToInt(item) + 1 // using 1-indexing in internal format
171-
}
172-
}
173-
val result = items.result()
174-
if (result.nonEmpty) {
175-
containsFreqItems = true
176-
allItems ++= result.sorted
177-
allItems += 0
178-
}
179-
}
180-
if (containsFreqItems) {
181-
Iterator.single(allItems.result())
182-
} else {
183-
Iterator.empty
184-
}
185-
}.persist(StorageLevel.MEMORY_AND_DISK)
152+
val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt)
153+
.persist(StorageLevel.MEMORY_AND_DISK)
186154

187155
val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize)
188156

@@ -231,6 +199,68 @@ class PrefixSpan private (
231199
@Since("1.5.0")
232200
object PrefixSpan extends Logging {
233201

202+
/**
203+
* This methods finds all frequent items in a input dataset.
204+
*
205+
* @param data Sequences of itemsets.
206+
* @param minCount The minimal number of sequence an item should be present in to be frequent
207+
*
208+
* @return An array of Item containing only frequent items.
209+
*/
210+
private[fpm] def findFrequentItems[Item: ClassTag](data : RDD[Array[Array[Item]]],
211+
minCount : Long): Array[Item] = {
212+
213+
data.flatMap { itemsets =>
214+
val uniqItems = mutable.Set.empty[Item]
215+
itemsets.foreach { _.foreach { item =>
216+
uniqItems += item
217+
}}
218+
uniqItems.toIterator.map((_, 1L))
219+
}.reduceByKey(_ + _).filter { case (_, count) =>
220+
count >= minCount
221+
}.sortBy(-_._2).map(_._1).collect()
222+
}
223+
224+
/**
225+
* This methods cleans the input dataset from un-frequent items, and translate it's item
226+
* to their corresponding Int identifier.
227+
*
228+
* @param data Sequences of itemsets.
229+
* @param itemToInt A map allowing translation of frequent Items to their Int Identifier.
230+
* The map should only contain frequent item.
231+
*
232+
* @return The internal repr of the inputted dataset. With properly placed zero delimiter.
233+
*/
234+
private[fpm] def toDatabaseInternalRepr[Item: ClassTag](data : RDD[Array[Array[Item]]],
235+
itemToInt : Map[Item, Int]):
236+
RDD[Array[Int]] = {
237+
238+
data.flatMap { itemsets =>
239+
val allItems = mutable.ArrayBuilder.make[Int]
240+
var containsFreqItems = false
241+
allItems += 0
242+
itemsets.foreach { itemsets =>
243+
val items = mutable.ArrayBuilder.make[Int]
244+
itemsets.foreach { item =>
245+
if (itemToInt.contains(item)) {
246+
items += itemToInt(item) + 1 // using 1-indexing in internal format
247+
}
248+
}
249+
val result = items.result()
250+
if (result.nonEmpty) {
251+
containsFreqItems = true
252+
allItems ++= result.sorted
253+
allItems += 0
254+
}
255+
}
256+
if (containsFreqItems) {
257+
Iterator.single(allItems.result())
258+
} else {
259+
Iterator.empty
260+
}
261+
}
262+
}
263+
234264
/**
235265
* Find the complete set of frequent sequential patterns in the input sequences.
236266
* @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int],

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,55 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
360360
compareResults(expected, model.freqSequences.collect())
361361
}
362362

363+
test("PrefixSpan pre-processing's cleaning test") {
364+
365+
// One item per itemSet
366+
val itemToInt1 = (4 to 5).zipWithIndex.toMap
367+
val sequences1 = Seq(
368+
Array(Array(4), Array(1), Array(2), Array(5), Array(2), Array(4), Array(5)),
369+
Array(Array(6), Array(7), Array(8)))
370+
val rdd1 = sc.parallelize(sequences1, 2).cache()
371+
372+
val cleanedSequence1 = PrefixSpan.toDatabaseInternalRepr(rdd1, itemToInt1).collect()
373+
374+
val expected1 = Array(Array(0, 4, 0, 5, 0, 4, 0, 5, 0))
375+
.map(x => x.map(y => {
376+
if (y == 0) 0
377+
else itemToInt1(y) + 1
378+
}))
379+
380+
compareInternalSequences(expected1, cleanedSequence1)
381+
382+
// Multi-item sequence
383+
val itemToInt2 = (4 to 6).zipWithIndex.toMap
384+
val sequences2 = Seq(
385+
Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)),
386+
Array(Array(8, 9), Array(1, 2)))
387+
val rdd2 = sc.parallelize(sequences2, 2).cache()
388+
389+
val cleanedSequence2 = PrefixSpan.toDatabaseInternalRepr(rdd2, itemToInt2).collect()
390+
391+
val expected2 = Array(Array(0, 4, 5, 0, 6, 0, 5, 0, 4, 0, 5, 6, 0))
392+
.map(x => x.map(y => {
393+
if (y == 0) 0
394+
else itemToInt2(y) + 1
395+
}))
396+
397+
compareInternalSequences(expected2, cleanedSequence2)
398+
399+
// Emptied sequence
400+
val itemToInt3 = (10 to 10).zipWithIndex.toMap
401+
val sequences3 = Seq(
402+
Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)),
403+
Array(Array(8, 9), Array(1, 2)))
404+
val rdd3 = sc.parallelize(sequences3, 2).cache()
405+
406+
val cleanedSequence3 = PrefixSpan.toDatabaseInternalRepr(rdd3, itemToInt3).collect()
407+
val expected3: Array[Array[Int]] = Array()
408+
409+
compareInternalSequences(expected3, cleanedSequence3)
410+
}
411+
363412
test("model save/load") {
364413
val sequences = Seq(
365414
Array(Array(1, 2), Array(3)),
@@ -409,4 +458,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
409458
val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet
410459
assert(expectedSet === actualSet)
411460
}
461+
462+
private def compareInternalSequences(
463+
expectedValue: Array[Array[Int]],
464+
actualValue: Array[Array[Int]]): Unit = {
465+
val expectedSet = expectedValue.map(x => x.toSeq).toSet
466+
val actualSet = actualValue.map(x => x.toSeq).toSet
467+
assert(expectedSet === actualSet)
468+
}
412469
}

0 commit comments

Comments
 (0)