@@ -103,45 +103,49 @@ class PrefixSpan private (
103103 // Convert min support to a min number of transactions for this dataset
104104 val minCount = if (minSupport == 0 ) 0L else math.ceil(sequences.count() * minSupport).toLong
105105
106- val itemCounts = sequences
106+ // Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
107+ val freqItemCounts = sequences
107108 .flatMap(seq => seq.distinct.map(item => (item, 1L )))
108109 .reduceByKey(_ + _)
109110 .filter(_._2 >= minCount)
110- var allPatternAndCounts = itemCounts.map(x => (List (x._1), x._2))
111111
112- val prefixSuffixPairs = {
113- val frequentItems = itemCounts.map(_._1).collect()
114- val candidates = sequences.map { p =>
115- p.filter (frequentItems.contains(_) )
116- }
117- candidates.flatMap { x =>
118- frequentItems.map { y =>
119- val sub = LocalPrefixSpan .getSuffix(y, x)
120- (List (y), sub)
121- }.filter(_._2.nonEmpty)
112+ // Pairs of (length 1 prefix, suffix consisting of frequent items)
113+ val itemSuffixPairs = {
114+ val freqItems = freqItemCounts.keys.collect().toSet
115+ sequences.flatMap { seq =>
116+ freqItems.flatMap { item =>
117+ val candidateSuffix = LocalPrefixSpan .getSuffix(item, seq.filter(freqItems.contains(_)))
118+ candidateSuffix match {
119+ case suffix if ! suffix.isEmpty => Some ((List (item), suffix))
120+ case _ => None
121+ }
122+ }
122123 }
123124 }
124- var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = partitionByProjDBSize(prefixSuffixPairs)
125125
126- while (largePrefixSuffixPairs.count() != 0 ) {
126+ // Accumulator for the computed results to be returned
127+ var resultsAccumulator = freqItemCounts.map(x => (List (x._1), x._2))
128+
129+ // Remaining work to be locally and distributively processed respectively
130+ var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
131+
132+ // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
133+ // projected database sizes <= `maxLocalProjDBSize`)
134+ while (pairsForDistributed.count() != 0 ) {
127135 val (nextPatternAndCounts, nextPrefixSuffixPairs) =
128- getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs )
129- largePrefixSuffixPairs .unpersist()
136+ getPatternCountsAndPrefixSuffixPairs(minCount, pairsForDistributed )
137+ pairsForDistributed .unpersist()
130138 val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
131- largePrefixSuffixPairs = largerPairsPart
132- largePrefixSuffixPairs .persist(StorageLevel .MEMORY_AND_DISK )
133- smallPrefixSuffixPairs ++= smallerPairsPart
134- allPatternAndCounts ++= nextPatternAndCounts
139+ pairsForDistributed = largerPairsPart
140+ pairsForDistributed .persist(StorageLevel .MEMORY_AND_DISK )
141+ pairsForLocal ++= smallerPairsPart
142+ resultsAccumulator ++= nextPatternAndCounts
135143 }
136144
137- if (smallPrefixSuffixPairs.count() > 0 ) {
138- val projectedDatabase = smallPrefixSuffixPairs
139- // TODO aggregateByKey
140- .groupByKey()
141- val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
142- allPatternAndCounts ++= nextPatternAndCounts
143- }
144- allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
145+ // Process the small projected databases locally
146+ resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey())
147+
148+ resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
145149 }
146150
147151
@@ -177,8 +181,8 @@ class PrefixSpan private (
177181 */
178182 private def getPatternCountsAndPrefixSuffixPairs (
179183 minCount : Long ,
180- prefixSuffixPairs : RDD [(List [Int ], Array [Int ])]):
181- (RDD [(List [Int ], Long )], RDD [(List [Int ], Array [Int ])]) = {
184+ prefixSuffixPairs : RDD [(List [Int ], Array [Int ])])
185+ : (RDD [(List [Int ], Long )], RDD [(List [Int ], Array [Int ])]) = {
182186 val prefixAndFrequentItemAndCounts = prefixSuffixPairs
183187 .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L )) }
184188 .reduceByKey(_ + _)
0 commit comments