1717
1818package org .apache .spark .mllib .fpm
1919
20+ import org .apache .spark .Logging
2021import org .apache .spark .annotation .Experimental
2122import org .apache .spark .rdd .RDD
23+ import org .apache .spark .storage .StorageLevel
2224
2325/**
2426 *
@@ -37,165 +39,206 @@ import org.apache.spark.rdd.RDD
3739 * (Wikipedia)]]
3840 */
3941@ Experimental
40- class PrefixSpan (
42+ class PrefixSpan private (
4143 private var minSupport : Double ,
42- private var maxPatternLength : Int ) extends java.io.Serializable {
43-
44- private var absMinSupport : Int = 0
44+ private var maxPatternLength : Int ) extends Logging with Serializable {
4545
/**
 * Creates an instance configured with the default parameters
 * {minSupport: `0.1`, maxPatternLength: `10`}.
 */
def this() = this(minSupport = 0.1, maxPatternLength = 10)
5151
/**
 * Sets the minimal support level (default: `0.1`).
 *
 * @param minSupport relative support threshold; must lie in `[0, 1]`
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if `minSupport` is outside `[0, 1]`
 */
def setMinSupport(minSupport: Double): this.type = {
  // Include the offending value so a misconfiguration is diagnosable from the message alone.
  require(minSupport >= 0 && minSupport <= 1,
    s"The minimal support level must be in [0, 1], but got $minSupport.")
  this.minSupport = minSupport
  this
}
5960
/**
 * Sets the maximal pattern length (default: `10`).
 *
 * @param maxPatternLength upper bound on pattern length; must be at least 1
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if `maxPatternLength` is smaller than 1
 */
def setMaxPatternLength(maxPatternLength: Int): this.type = {
  // Include the offending value so a misconfiguration is diagnosable from the message alone.
  require(maxPatternLength >= 1,
    s"The maximal pattern length must be at least 1, but got $maxPatternLength.")
  this.maxPatternLength = maxPatternLength
  this
}
6769
/**
 * Finds the complete set of sequential patterns in the input sequences.
 *
 * The computation proceeds in four steps:
 *  1. derive the absolute support threshold from `minSupport` and the input size;
 *  2. find the frequent length-one patterns and emit (prefix, suffix) pairs for them;
 *  3. group the suffixes by prefix into prefix-projected databases;
 *  4. mine every projected database locally for longer patterns.
 *
 * @param sequences input data set, contains a set of sequences,
 *                  a sequence is an ordered list of elements.
 * @return a set of sequential pattern pairs,
 *         the key of pair is pattern (a list of elements),
 *         the value of pair is the pattern's support value.
 */
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
  if (sequences.getStorageLevel == StorageLevel.NONE) {
    // The input is traversed more than once below (counting, item counting, projection).
    logWarning("Input data is not cached.")
  }
  val minCount = getAbsoluteMinSupport(sequences)
  val (lengthOnePatternsAndCounts, prefixAndCandidates) =
    findLengthOnePatterns(minCount, sequences)
  val projectedDatabases = makePrefixProjectedDatabases(prefixAndCandidates)
  val longerPatterns = getPatternsInLocal(minCount, projectedDatabases)
  // Lift each frequent item into a one-element pattern so both result sets share one type.
  lengthOnePatternsAndCounts.map { case (item, count) => (Array(item), count) } ++ longerPatterns
}
8690
/**
 * Gets the absolute minimum support value, i.e. the smallest number of sequences that
 * must contain a pattern for it to reach the relative `minSupport` level.
 *
 * @param sequences input data set, contains a set of sequences
 * @return absolute minimum support value (sequences count * minSupport, rounded up)
 */
private def getAbsoluteMinSupport(sequences: RDD[Array[Int]]): Long = {
  if (minSupport == 0) {
    // Every pattern qualifies, so there is no need to run a count() job.
    0L
  } else {
    // Round up rather than truncate: with 3 sequences and minSupport 0.5 the product is 1.5,
    // and a pattern seen in a single sequence (support 1/3) must not pass the threshold.
    math.ceil(sequences.count() * minSupport).toLong
  }
}
9799
/**
 * Counts, for every item, the number of sequences containing it and keeps the items
 * whose count reaches the minimal support. No ordering of the result is established here.
 *
 * @param minCount the absolute minimum support
 * @param sequences original sequences data
 * @return (item, count) pairs of the frequent items
 */
private def getFreqItemAndCounts(
    minCount: Long,
    sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
  // `distinct` makes each sequence contribute at most one occurrence per item.
  val occurrences = sequences.flatMap { sequence =>
    sequence.distinct.map(item => (item, 1L))
  }
  val itemCounts = occurrences.reduceByKey((left, right) => left + right)
  itemCounts.filter { case (_, count) => count >= minCount }
}
113+
/**
 * Local counterpart of the frequent-item search: counts, for every item, the number of
 * sequences containing it and keeps the items whose count reaches the minimal support.
 * No ordering of the result is established here.
 *
 * @param minCount the absolute minimum support
 * @param sequences sequences data
 * @return (item, count) pairs of the frequent items
 */
private def getFreqItemAndCounts(
    minCount: Long,
    sequences: Array[Array[Int]]): Array[(Int, Long)] = {
  // `distinct` makes each sequence contribute at most one occurrence per item.
  val occurrences = sequences.flatMap(sequence => sequence.distinct)
  val itemCounts = occurrences
    .groupBy(identity)
    .map { case (item, hits) => (item, hits.length.toLong) }
  itemCounts.filter { case (_, count) => count >= minCount }.toArray
}
129+
/**
 * Gets the projected database of every frequent length-one prefix.
 *
 * Each sequence is first reduced to its frequent items; then, for every frequent item,
 * the suffix returned by `getSuffix` is emitted together with the one-element prefix.
 * Pairs with an empty suffix are dropped.
 *
 * @param frequentPrefixes frequent prefixes
 * @param sequences sequences data
 * @return (prefix, suffix) pairs with non-empty suffixes
 */
private def getPatternAndProjectedDatabase(
    frequentPrefixes: Array[Int],
    sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
  // Hash-set membership replaces a linear Array.contains scan for every sequence element.
  val frequentItems = frequentPrefixes.toSet
  sequences.flatMap { sequence =>
    val filtered = sequence.filter(frequentItems.contains)
    frequentPrefixes
      .map(item => (Array(item), getSuffix(item, filtered)))
      .filter { case (_, suffix) => suffix.nonEmpty }
  }
}
133149
/**
 * Gets the projected databases of the frequent extensions of a prefix, computed locally.
 *
 * Each sequence is first reduced to its frequent items; then, for every frequent item,
 * the non-empty suffixes returned by `getSuffix` form the projected database of the
 * prefix extended by that item. Extensions with an empty database are dropped.
 *
 * @param prePrefix the frequent prefixes' prefix
 * @param frequentPrefixes frequent prefixes
 * @param sequences sequences data
 * @return extended prefixes and their non-empty projected databases
 */
private def getPatternAndProjectedDatabase(
    prePrefix: Array[Int],
    frequentPrefixes: Array[Int],
    sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
  // Hash-set membership replaces a linear Array.contains scan for every sequence element.
  val frequentItems = frequentPrefixes.toSet
  val filteredProjectedDatabase = sequences.map(_.filter(frequentItems.contains))
  frequentPrefixes.map { item =>
    val suffixes = filteredProjectedDatabase.map(getSuffix(item, _)).filter(_.nonEmpty)
    (prePrefix :+ item, suffixes)
  }.filter { case (_, suffixes) => suffixes.nonEmpty }
}
168+
/**
 * Finds the patterns of length one together with their projection table.
 *
 * @param minCount the absolute minimum support
 * @param sequences original sequences data
 * @return the frequent length-one patterns with their counts, and the
 *         (prefix, suffix) pairs projected from those patterns
 */
private def findLengthOnePatterns(
    minCount: Long,
    sequences: RDD[Array[Int]]): (RDD[(Int, Long)], RDD[(Array[Int], Array[Int])]) = {
  val itemsWithCounts = getFreqItemAndCounts(minCount, sequences)
  // The frequent items are brought to the driver so they can be used inside the projection.
  val frequentItems = itemsWithCounts.keys.collect()
  val projected = getPatternAndProjectedDatabase(frequentItems, sequences)
  (itemsWithCounts, projected)
}
183+
/**
 * Builds prefix-projected databases by grouping (prefix, suffix) pairs on the prefix.
 *
 * @param data patterns and projected sequences data before re-partition
 * @return patterns and projected sequences data after re-partition
 */
private def makePrefixProjectedDatabases(
    data: RDD[(Array[Int], Array[Int])]): RDD[(Array[Int], Array[Array[Int]])] = {
  // The prefix is turned into a Seq for grouping (presumably because arrays compare by
  // reference, not by content) and turned back into an array afterwards.
  val keyedBySeq = data.map { case (prefix, suffix) => (prefix.toSeq, suffix) }
  val grouped = keyedBySeq.groupByKey()
  grouped.map { case (prefixKey, suffixes) => (prefixKey.toArray, suffixes.toArray) }
}
146195
/**
 * Mines every prefix-projected database locally and flattens the results.
 *
 * @param minCount the absolute minimum support
 * @param data patterns and their projected sequences data
 * @return patterns with their counts
 */
private def getPatternsInLocal(
    minCount: Long,
    data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
  data.flatMap { case (prefix, projectedDatabase) =>
    getPatternsWithPrefix(minCount, prefix, projectedDatabase)
  }
}
159209
/**
 * Calculates, locally, all frequent patterns that extend one prefix.
 *
 * The frequent items of the projected database each extend the prefix by one element;
 * the search then recurses into the projected databases of those extensions for as long
 * as the resulting patterns stay within `maxPatternLength`.
 *
 * @param minCount the absolute minimum support
 * @param prefix prefix
 * @param projectedDatabase the projected sequences data of the prefix
 * @return patterns (each no longer than `maxPatternLength`) with their counts
 */
private def getPatternsWithPrefix(
    minCount: Long,
    prefix: Array[Int],
    projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
  if (prefix.length >= maxPatternLength) {
    // Any extension would exceed the limit. Without this guard a length-one prefix still
    // produced length-two patterns when maxPatternLength was 1.
    Array.empty[(Array[Int], Long)]
  } else {
    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    val frequentPatternAndCounts = frequentPrefixAndCounts
      .map { case (item, count) => (prefix :+ item, count) }
    if (prefix.length + 1 < maxPatternLength) {
      val prefixProjectedDatabases = getPatternAndProjectedDatabase(
        prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
      // flatMap yields nothing for an empty set of projected databases and avoids
      // re-copying the accumulated result for every child, unlike map + reduce(_ ++ _).
      val nextPatterns = prefixProjectedDatabases.flatMap { case (extendedPrefix, database) =>
        getPatternsWithPrefix(minCount, extendedPrefix, database)
      }
      frequentPatternAndCounts ++ nextPatterns
    } else {
      frequentPatternAndCounts
    }
  }
}
194237
195238 /**
196239 * calculate suffix sequence following a prefix in a sequence
197240 * @param prefix prefix
198- * @param sequence original sequence
241+ * @param sequence sequence
199242 * @return suffix sequence
200243 */
201244 private def getSuffix (prefix : Int , sequence : Array [Int ]): Array [Int ] = {
0 commit comments