1717
1818package org .apache .spark .mllib .fpm
1919
20+ import scala .collection .mutable .ArrayBuffer
21+
2022import org .apache .spark .Logging
2123import org .apache .spark .annotation .Experimental
2224import org .apache .spark .rdd .RDD
@@ -43,7 +45,7 @@ class PrefixSpan private (
    private var minSupport: Double,
    private var maxPatternLength: Int) extends Logging with Serializable {

  // Threshold on the running pattern count: while the number of patterns found
  // so far is at or below this value, candidate growth stays fully distributed;
  // once it is exceeded, `run` groups the remaining projected databases and
  // mines them locally per prefix (see getPatternsInLocal).
  private val minPatternsBeforeLocalProcessing: Int = 20
@@ -88,66 +90,65 @@ class PrefixSpan private (
8890 val prefixSuffixPairs = getPrefixSuffixPairs(
8991 lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
9092 var patternsCount : Long = lengthOnePatternsAndCounts.count()
91- var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2))
93+ var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer (x._1), x._2))
9294 var currentPrefixSuffixPairs = prefixSuffixPairs
93- while (patternsCount <= minPatternsBeforeShuffle && currentPrefixSuffixPairs.count() != 0 ) {
95+ var patternLength : Int = 1
96+ while (patternLength < maxPatternLength &&
97+ patternsCount <= minPatternsBeforeLocalProcessing &&
98+ currentPrefixSuffixPairs.count() != 0 ) {
9499 val (nextPatternAndCounts, nextPrefixSuffixPairs) =
95100 getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs)
96- patternsCount = nextPatternAndCounts.count().toInt
101+ patternsCount = nextPatternAndCounts.count()
97102 currentPrefixSuffixPairs = nextPrefixSuffixPairs
98103 allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
104+ patternLength = patternLength + 1
99105 }
100- if (patternsCount > 0 ) {
106+ if (patternLength < maxPatternLength && patternsCount > 0 ) {
101107 val projectedDatabase = currentPrefixSuffixPairs
102108 .map(x => (x._1.toSeq, x._2))
103109 .groupByKey()
104110 .map(x => (x._1.toArray, x._2.toArray))
105111 val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
106112 allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
107113 }
108- allPatternAndCounts
114+ allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
109115 }
110116
111117 /**
112118 * Get the pattern and counts, and prefix suffix pairs
113119 * @param minCount minimum count
114- * @param prefixSuffixPairs prefix and suffix pairs,
115- * @return pattern and counts, and prefix suffix pairs
116- * (Array [pattern, count], RDD[prefix, suffix ])
120+ * @param prefixSuffixPairs prefix (length n) and suffix pairs,
121+ * @return pattern (length n+1) and counts, and prefix (length n+1) and suffix pairs
122+ * (RDD [pattern, count], RDD[prefix, suffix ])
117123 */
118124 private def getPatternCountsAndPrefixSuffixPairs (
119125 minCount : Long ,
120- prefixSuffixPairs : RDD [(Array [Int ], Array [Int ])]):
121- (RDD [(Array [Int ], Long )], RDD [(Array [Int ], Array [Int ])]) = {
122- val prefixAndFreqentItemAndCounts = prefixSuffixPairs
123- .flatMap { case (prefix, suffix) =>
124- suffix.distinct.map(y => ((prefix.toSeq, y), 1L ))
125- }.reduceByKey(_ + _)
126+ prefixSuffixPairs : RDD [(ArrayBuffer [Int ], Array [Int ])]):
127+ (RDD [(ArrayBuffer [Int ], Long )], RDD [(ArrayBuffer [Int ], Array [Int ])]) = {
128+ val prefixAndFrequentItemAndCounts = prefixSuffixPairs
129+ .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L )) }
130+ .reduceByKey(_ + _)
126131 .filter(_._2 >= minCount)
127- val patternAndCounts = prefixAndFreqentItemAndCounts
128- .map{ case ((prefix, item), count) => (prefix.toArray :+ item, count) }
129- val prefixlength = prefixSuffixPairs.first()._1.length
130- if (prefixlength + 1 >= maxPatternLength) {
131- (patternAndCounts, prefixSuffixPairs.filter(x => false ))
132- } else {
133- val frequentItemsMap = prefixAndFreqentItemAndCounts
134- .keys
135- .groupByKey()
136- .mapValues(_.toSet)
137- .collect
138- .toMap
139- val nextPrefixSuffixPairs = prefixSuffixPairs
140- .filter(x => frequentItemsMap.contains(x._1))
141- .flatMap { case (prefix, suffix) =>
142- val frequentItemSet = frequentItemsMap(prefix)
143- val filteredSuffix = suffix.filter(frequentItemSet.contains(_))
144- val nextSuffixes = frequentItemSet.map{ item =>
145- (item, LocalPrefixSpan .getSuffix(item, filteredSuffix))
146- }.filter(_._2.nonEmpty)
147- nextSuffixes.map { case (item, suffix) => (prefix :+ item, suffix) }
132+ val patternAndCounts = prefixAndFrequentItemAndCounts
133+ .map { case ((prefix, item), count) => (prefix :+ item, count) }
134+ val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
135+ .keys
136+ .groupByKey()
137+ .mapValues(_.toSet)
138+ .collect()
139+ .toMap
140+ val nextPrefixSuffixPairs = prefixSuffixPairs
141+ .filter(x => prefixToFrequentNextItemsMap.contains(x._1))
142+ .flatMap { case (prefix, suffix) =>
143+ val frequentNextItems = prefixToFrequentNextItemsMap(prefix)
144+ val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
145+ frequentNextItems.flatMap { item =>
146+ val suffix = LocalPrefixSpan .getSuffix(item, filteredSuffix)
147+ if (suffix.isEmpty) None
148+ else Some (prefix :+ item, suffix)
148149 }
149- (patternAndCounts, nextPrefixSuffixPairs)
150150 }
151+ (patternAndCounts, nextPrefixSuffixPairs)
151152 }
152153
153154 /**
@@ -181,14 +182,14 @@ class PrefixSpan private (
181182 */
182183 private def getPrefixSuffixPairs (
183184 frequentPrefixes : Array [Int ],
184- sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Array [Int ])] = {
185+ sequences : RDD [Array [Int ]]): RDD [(ArrayBuffer [Int ], Array [Int ])] = {
185186 val filteredSequences = sequences.map { p =>
186187 p.filter (frequentPrefixes.contains(_) )
187188 }
188189 filteredSequences.flatMap { x =>
189190 frequentPrefixes.map { y =>
190191 val sub = LocalPrefixSpan .getSuffix(y, x)
191- (Array (y), sub)
192+ (ArrayBuffer (y), sub)
192193 }.filter(_._2.nonEmpty)
193194 }
194195 }
@@ -201,9 +202,9 @@ class PrefixSpan private (
201202 */
202203 private def getPatternsInLocal (
203204 minCount : Long ,
204- data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(Array [Int ], Long )] = {
205- data.flatMap { x =>
206- LocalPrefixSpan .run(minCount, maxPatternLength, x._1, x._2)
207- }
205+ data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(ArrayBuffer [Int ], Long )] = {
206+ data
207+ .flatMap { x => LocalPrefixSpan .run(minCount, maxPatternLength, x._1, x._2) }
208+ .map { case (pattern, count) => (pattern.to[ ArrayBuffer ], count) }
208209 }
209210}
0 commit comments