@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  // TODO: make configurable with a better default value, 10000 may be too small
+  private val maxLocalProjDBSize: Long = 10000
+
   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
    */
   def this() = this(0.1, 10)
 
+  /**
+   * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport: Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }
 
+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+   */
+  def getMaxPatternLength: Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
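For orientation, here is a minimal sketch, not part of this patch, of how the getters and setters in this hunk are meant to be used; the values and the variable name are illustrative. Both setters validate their argument and return this, so calls chain:

import org.apache.spark.mllib.fpm.PrefixSpan

// Defaults from the no-arg constructor: minSupport = 0.1, maxPatternLength = 10
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.2)        // must lie in [0, 1]
  .setMaxPatternLength(5)    // must be >= 1
assert(prefixSpan.getMinSupport == 0.2)
assert(prefixSpan.getMaxPatternLength == 5)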
@@ -78,81 +97,153 @@ class PrefixSpan private (
    *         the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
-    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
-    val groupedProjectedDatabase = prefixAndProjectedDatabase
-      .map(x => (x._1.toSeq, x._2))
-      .groupByKey()
-      .map(x => (x._1.toArray, x._2.toArray))
-    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-    val lengthOnePatternsAndCountsRdd =
-      sequences.sparkContext.parallelize(
-        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-    allPatterns
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // (Frequent item -> number of occurrences); all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+      .collect()
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.map(_._1).toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
+    }
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+    // frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be processed locally and in a distributed fashion, respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+    // projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts.collect()
+    }
+
+    // Process the small projected databases locally
+    val remainingResults = getPatternsInLocal(
+      minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
   }
 
+
   /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences,
-   * @return minimum count,
+   * Partitions the prefix-suffix pairs by projected database size.
+   * @param prefixSuffixPairs prefix (length n) and suffix pairs,
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+    (small.collect(), large)
   }
 
   /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return array of item and count pair
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
+   * and remaining work.
+   * @param minCount minimum count
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+   *         prefix, corresponding suffix) pairs.
    */
-  private def getFreqItemAndCounts(
+  private def extendPrefixes(
       minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
+    val prefixItemPairAndCounts = prefixSuffixPairs
+      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-  }
 
-  /**
-   * Get the frequent prefixes' projected database.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
-   */
-  private def getPrefixAndProjectedDatabase(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter(frequentPrefixes.contains(_))
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
+      .keys
+      .groupByKey()
+      .mapValues(_.toSet)
+      .collect()
+      .toMap
+
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some((item :: prefix, suffix))
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }
 
   /**
-   * calculate the patterns in local.
+   * Calculate the patterns locally.
    * @param minCount the absolute minimum count
-   * @param data patterns and projected sequences data data
+   * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+    data.flatMap {
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }
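As a hedged end-to-end sketch of how the reworked run method could be exercised: the toy sequences, the local master, and the printing below are illustrative assumptions, not part of the patch. With 4 sequences and minSupport = 0.5, minCount = ceil(4 * 0.5) = 2, so a pattern must occur in at least 2 sequences to be reported; internally, run alternates distributed prefix extension with local processing once every projected database fits under maxLocalProjDBSize.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.storage.StorageLevel

object PrefixSpanRunExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanRunExample").setMaster("local[2]"))

    // Toy input: each sequence is an ordered array of integer item IDs.
    // Caching avoids the "Input data is not cached." warning emitted by run().
    val sequences = sc.parallelize(Seq(
      Array(1, 2, 3),
      Array(1, 3, 2),
      Array(1, 2, 5),
      Array(6)
    )).persist(StorageLevel.MEMORY_AND_DISK)

    // minSupport = 0.5 over 4 sequences gives minCount = 2; patterns longer than 5 items are skipped.
    val freqPatterns = new PrefixSpan()
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
      .run(sequences)

    freqPatterns.collect().foreach { case (pattern, count) =>
      println(pattern.mkString("[", ", ", "]") + ": " + count)
    }

    sc.stop()
  }
}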