@@ -53,7 +53,8 @@ class PrefixSpan private (
5353 * Sets the minimal support level (default: `0.1`).
5454 */
5555 def setMinSupport (minSupport : Double ): this .type = {
56- require(minSupport >= 0 && minSupport <= 1 )
56+ require(minSupport >= 0 && minSupport <= 1 ,
57+ " The minimum support value must be between 0 and 1, including 0 and 1." )
5758 this .minSupport = minSupport
5859 this
5960 }
@@ -62,7 +63,8 @@ class PrefixSpan private (
6263 * Sets maximal pattern length (default: `10`).
6364 */
6465 def setMaxPatternLength (maxPatternLength : Int ): this .type = {
65- require(maxPatternLength >= 1 )
66+ require(maxPatternLength >= 1 ,
67+ " The maximum pattern length value must be greater than 0." )
6668 this .maxPatternLength = maxPatternLength
6769 this
6870 }
@@ -73,35 +75,38 @@ class PrefixSpan private (
7375 * a sequence is an ordered list of elements.
7476 * @return a set of sequential pattern pairs,
7577 * the key of pair is pattern (a list of elements),
76- * the value of pair is the pattern's support value .
78+ * the value of pair is the pattern's count .
7779 */
7880 def run (sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Long )] = {
7981 if (sequences.getStorageLevel == StorageLevel .NONE ) {
8082 logWarning(" Input data is not cached." )
8183 }
82- val minCount = getAbsoluteMinSupport (sequences)
84+ val minCount = getMinCount (sequences)
8385 val (lengthOnePatternsAndCounts, prefixAndCandidates) =
8486 findLengthOnePatterns(minCount, sequences)
85- val repartitionedRdd = makePrefixProjectedDatabases(prefixAndCandidates)
86- val nextPatterns = getPatternsInLocal(minCount, repartitionedRdd)
87- val allPatterns = lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)) ++ nextPatterns
87+ val projectedDatabase = makePrefixProjectedDatabases(prefixAndCandidates)
88+ val nextPatterns = getPatternsInLocal(minCount, projectedDatabase)
89+ val lengthOnePatternsAndCountsRdd =
90+ sequences.sparkContext.parallelize(
91+ lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)))
92+ val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
8893 allPatterns
8994 }
9095
9196 /**
92- * Get the absolute minimum support value (sequences count * minSupport).
97+ * Get the minimum count (sequences count * minSupport).
9398 * @param sequences input data set, contains a set of sequences,
94- * @return absolute minimum support value ,
99+ * @return minimum count ,
95100 */
96- private def getAbsoluteMinSupport (sequences : RDD [Array [Int ]]): Long = {
97- if (minSupport == 0 ) 0L else (sequences.count() * minSupport).toLong
101+ private def getMinCount (sequences : RDD [Array [Int ]]): Long = {
102+ if (minSupport == 0 ) 0L else math.ceil (sequences.count() * minSupport).toLong
98103 }
99104
100105 /**
101- * Generates frequent items by filtering the input data using minimal support level.
102- * @param minCount the absolute minimum support
106+ * Generates frequent items by filtering the input data using minimal count level.
107+ * @param minCount the absolute minimum count
103108 * @param sequences original sequences data
104- * @return array of frequent pattern ordered by their frequencies
109+ * @return array of item and count pair
105110 */
106111 private def getFreqItemAndCounts (
107112 minCount : Long ,
@@ -111,22 +116,6 @@ class PrefixSpan private (
111116 .filter(_._2 >= minCount)
112117 }
113118
114- /**
115- * Generates frequent items by filtering the input data using minimal support level.
116- * @param minCount the absolute minimum support
117- * @param sequences sequences data
118- * @return array of frequent pattern ordered by their frequencies
119- */
120- private def getFreqItemAndCounts (
121- minCount : Long ,
122- sequences : Array [Array [Int ]]): Array [(Int , Long )] = {
123- sequences.flatMap(_.distinct)
124- .groupBy(x => x)
125- .mapValues(_.length.toLong)
126- .filter(_._2 >= minCount)
127- .toArray
128- }
129-
130119 /**
131120 * Get the frequent prefixes' projected database.
132121 * @param frequentPrefixes frequent prefixes
@@ -141,44 +130,25 @@ class PrefixSpan private (
141130 }
142131 filteredSequences.flatMap { x =>
143132 frequentPrefixes.map { y =>
144- val sub = getSuffix(y, x)
133+ val sub = LocalPrefixSpan . getSuffix(y, x)
145134 (Array (y), sub)
146- }
147- }.filter(x => x._2.nonEmpty)
148- }
149-
150- /**
151- * Get the frequent prefixes' projected database.
152- * @param prePrefix the frequent prefixes' prefix
153- * @param frequentPrefixes frequent prefixes
154- * @param sequences sequences data
155- * @return prefixes and projected database
156- */
157- private def getPatternAndProjectedDatabase (
158- prePrefix : Array [Int ],
159- frequentPrefixes : Array [Int ],
160- sequences : Array [Array [Int ]]): Array [(Array [Int ], Array [Array [Int ]])] = {
161- val filteredProjectedDatabase = sequences
162- .map(x => x.filter(frequentPrefixes.contains(_)))
163- frequentPrefixes.map { x =>
164- val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
165- (prePrefix ++ Array (x), sub)
166- }.filter(x => x._2.nonEmpty)
135+ }.filter(_._2.nonEmpty)
136+ }
167137 }
168138
169139 /**
170140 * Find the patterns that it's length is one
171- * @param minCount the absolute minimum support
141+ * @param minCount the minimum count
172142 * @param sequences original sequences data
173143 * @return length-one patterns and projection table
174144 */
175145 private def findLengthOnePatterns (
176146 minCount : Long ,
177- sequences : RDD [Array [Int ]]): (RDD [(Int , Long )], RDD [(Array [Int ], Array [Int ])]) = {
147+ sequences : RDD [Array [Int ]]): (Array [(Int , Long )], RDD [(Array [Int ], Array [Int ])]) = {
178148 val frequentLengthOnePatternAndCounts = getFreqItemAndCounts(minCount, sequences)
179149 val prefixAndProjectedDatabase = getPatternAndProjectedDatabase(
180150 frequentLengthOnePatternAndCounts.keys.collect(), sequences)
181- (frequentLengthOnePatternAndCounts, prefixAndProjectedDatabase)
151+ (frequentLengthOnePatternAndCounts.collect() , prefixAndProjectedDatabase)
182152 }
183153
184154 /**
@@ -195,58 +165,15 @@ class PrefixSpan private (
195165
196166 /**
197167 * calculate the patterns in local.
198- * @param minCount the absolute minimum support
168+ * @param minCount the absolute minimum count
199169 * @param data patterns and projected sequences data data
200170 * @return patterns
201171 */
202172 private def getPatternsInLocal (
203173 minCount : Long ,
204174 data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(Array [Int ], Long )] = {
205175 data.flatMap { x =>
206- getPatternsWithPrefix(minCount, x._1, x._2)
207- }
208- }
209-
210- /**
211- * calculate the patterns with one prefix in local.
212- * @param minCount the absolute minimum support
213- * @param prefix prefix
214- * @param projectedDatabase patterns and projected sequences data
215- * @return patterns
216- */
217- private def getPatternsWithPrefix (
218- minCount : Long ,
219- prefix : Array [Int ],
220- projectedDatabase : Array [Array [Int ]]): Array [(Array [Int ], Long )] = {
221- val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
222- val frequentPatternAndCounts = frequentPrefixAndCounts
223- .map(x => (prefix ++ Array (x._1), x._2))
224- val prefixProjectedDatabases = getPatternAndProjectedDatabase(
225- prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
226-
227- val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
228- if (continueProcess) {
229- val nextPatterns = prefixProjectedDatabases
230- .map(x => getPatternsWithPrefix(minCount, x._1, x._2))
231- .reduce(_ ++ _)
232- frequentPatternAndCounts ++ nextPatterns
233- } else {
234- frequentPatternAndCounts
235- }
236- }
237-
238- /**
239- * calculate suffix sequence following a prefix in a sequence
240- * @param prefix prefix
241- * @param sequence sequence
242- * @return suffix sequence
243- */
244- private def getSuffix (prefix : Int , sequence : Array [Int ]): Array [Int ] = {
245- val index = sequence.indexOf(prefix)
246- if (index == - 1 ) {
247- Array ()
248- } else {
249- sequence.drop(index + 1 )
176+ LocalPrefixSpan .run(minCount, maxPatternLength, x._1, x._2)
250177 }
251178 }
252179}
0 commit comments