@@ -45,7 +45,7 @@ class PrefixSpan private (
4545 private var minSupport : Double ,
4646 private var maxPatternLength : Int ) extends Logging with Serializable {
4747
48- private val minPatternsBeforeLocalProcessing : Int = 20
48+ private val maxSuffixesBeforeLocalProcessing : Long = 10000
4949
5050 /**
5151 * Constructs a default instance with default parameters
@@ -91,20 +91,25 @@ class PrefixSpan private (
9191 lengthOnePatternsAndCounts.map(_._1).collect(), sequences)
9292 var patternsCount : Long = lengthOnePatternsAndCounts.count()
9393 var allPatternAndCounts = lengthOnePatternsAndCounts.map(x => (ArrayBuffer (x._1), x._2))
94- var currentPrefixSuffixPairs = prefixSuffixPairs
94+ var (smallPrefixSuffixPairs, largePrefixSuffixPairs) =
95+ splitPrefixSuffixPairs(prefixSuffixPairs)
96+ largePrefixSuffixPairs.persist(StorageLevel .MEMORY_AND_DISK )
9597 var patternLength : Int = 1
9698 while (patternLength < maxPatternLength &&
97- patternsCount <= minPatternsBeforeLocalProcessing &&
98- currentPrefixSuffixPairs.count() != 0 ) {
99+ largePrefixSuffixPairs.count() != 0 ) {
99100 val (nextPatternAndCounts, nextPrefixSuffixPairs) =
100- getPatternCountsAndPrefixSuffixPairs(minCount, currentPrefixSuffixPairs )
101+ getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs )
101102 patternsCount = nextPatternAndCounts.count()
102- currentPrefixSuffixPairs = nextPrefixSuffixPairs
103+ largePrefixSuffixPairs.unpersist()
104+ val splitedPrefixSuffixPairs = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
105+ largePrefixSuffixPairs = splitedPrefixSuffixPairs._2
106+ largePrefixSuffixPairs.persist(StorageLevel .MEMORY_AND_DISK )
107+ smallPrefixSuffixPairs = smallPrefixSuffixPairs ++ splitedPrefixSuffixPairs._1
103108 allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
104109 patternLength = patternLength + 1
105110 }
106- if (patternLength < maxPatternLength && patternsCount > 0 ) {
107- val projectedDatabase = currentPrefixSuffixPairs
111+ if (smallPrefixSuffixPairs.count() > 0 ) {
112+ val projectedDatabase = smallPrefixSuffixPairs
108113 .map(x => (x._1.toSeq, x._2))
109114 .groupByKey()
110115 .map(x => (x._1.toArray, x._2.toArray))
@@ -114,6 +119,38 @@ class PrefixSpan private (
114119 allPatternAndCounts.map { case (pattern, count) => (pattern.toArray, count) }
115120 }
116121
122+
123+ /**
124+ * Split prefix suffix pairs to two parts:
125+ * suffixes' size less than maxSuffixesBeforeLocalProcessing and
126+ * suffixes' size more than maxSuffixesBeforeLocalProcessing
127+ * @param prefixSuffixPairs prefix (length n) and suffix pairs,
128+ * @return small size prefix suffix pairs and big size prefix suffix pairs
129+ * (RDD[prefix, suffix], RDD[prefix, suffix ])
130+ */
131+ private def splitPrefixSuffixPairs (
132+ prefixSuffixPairs : RDD [(ArrayBuffer [Int ], Array [Int ])]):
133+ (RDD [(ArrayBuffer [Int ], Array [Int ])], RDD [(ArrayBuffer [Int ], Array [Int ])]) = {
134+ val suffixSizeMap = prefixSuffixPairs
135+ .map(x => (x._1, x._2.length))
136+ .reduceByKey(_ + _)
137+ .map(x => (x._2 <= maxSuffixesBeforeLocalProcessing, Set (x._1)))
138+ .reduceByKey(_ ++ _)
139+ .collect
140+ .toMap
141+ val small = if (suffixSizeMap.contains(true )) {
142+ prefixSuffixPairs.filter(x => suffixSizeMap(true ).contains(x._1))
143+ } else {
144+ prefixSuffixPairs.filter(x => false )
145+ }
146+ val large = if (suffixSizeMap.contains(false )) {
147+ prefixSuffixPairs.filter(x => suffixSizeMap(false ).contains(x._1))
148+ } else {
149+ prefixSuffixPairs.filter(x => false )
150+ }
151+ (small, large)
152+ }
153+
117154 /**
118155 * Get the pattern and counts, and prefix suffix pairs
119156 * @param minCount minimum count
@@ -205,7 +242,7 @@ class PrefixSpan private (
205242 data : RDD [(Array [Int ], Array [Array [Int ]])]): RDD [(ArrayBuffer [Int ], Long )] = {
206243 data.flatMap {
207244 case (prefix, projDB) =>
208- LocalPrefixSpan .run(minCount, maxPatternLength, prefix.toList, projDB)
245+ LocalPrefixSpan .run(minCount, maxPatternLength, prefix.toList.reverse , projDB)
209246 .map { case (pattern : List [Int ], count : Long ) =>
210247 (pattern.toArray.reverse.to[ArrayBuffer ], count)
211248 }
0 commit comments