Skip to content

Commit 4dd1c8a

Browse files
committed
initialize file before rebase.
1 parent 078d410 commit 4dd1c8a

File tree

1 file changed

+10
-65
lines changed

1 file changed

+10
-65
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 10 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ class PrefixSpan private (
4343
private var minSupport: Double,
4444
private var maxPatternLength: Int) extends Logging with Serializable {
4545

46-
private val minPatternsBeforeShuffle: Int = 20
47-
4846
/**
4947
* Constructs a default instance with default parameters
5048
* {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -88,69 +86,16 @@ class PrefixSpan private (
8886
getFreqItemAndCounts(minCount, sequences).collect()
8987
val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
9088
lengthOnePatternsAndCounts.map(_._1), sequences)
91-
92-
var patternsCount = lengthOnePatternsAndCounts.length
93-
var allPatternAndCounts = sequences.sparkContext.parallelize(
94-
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
95-
var currentProjectedDatabase = prefixAndProjectedDatabase
96-
while (patternsCount <= minPatternsBeforeShuffle &&
97-
currentProjectedDatabase.count() != 0) {
98-
val (nextPatternAndCounts, nextProjectedDatabase) =
99-
getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
100-
patternsCount = nextPatternAndCounts.count().toInt
101-
currentProjectedDatabase = nextProjectedDatabase
102-
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
103-
}
104-
if (patternsCount > 0) {
105-
val groupedProjectedDatabase = currentProjectedDatabase
106-
.map(x => (x._1.toSeq, x._2))
107-
.groupByKey()
108-
.map(x => (x._1.toArray, x._2.toArray))
109-
val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
110-
allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
111-
}
112-
allPatternAndCounts
113-
}
114-
115-
/**
116-
* Get the patterns and their counts, and the projected database.
117-
* @param minCount minimum count
118-
* @param prefixAndProjectedDatabase prefix and projected database
119-
* @return pattern and counts, and projected database
120-
* (RDD[pattern, count], RDD[prefix, projected database])
121-
*/
122-
private def getPatternCountsAndProjectedDatabase(
123-
minCount: Long,
124-
prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
125-
(RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
126-
val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x =>
127-
x._2.distinct.map(y => ((x._1.toSeq, y), 1L))
128-
}.reduceByKey(_ + _)
129-
.filter(_._2 >= minCount)
130-
val patternAndCounts = prefixAndFreqentItemAndCounts
131-
.map(x => (x._1._1.toArray ++ Array(x._1._2), x._2))
132-
val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length
133-
if (prefixlength + 1 >= maxPatternLength) {
134-
(patternAndCounts, prefixAndProjectedDatabase.filter(x => false))
135-
} else {
136-
val frequentItemsMap = prefixAndFreqentItemAndCounts
137-
.keys.map(x => (x._1, x._2))
138-
.groupByKey()
139-
.mapValues(_.toSet)
140-
.collect
141-
.toMap
142-
val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase
143-
.filter(x => frequentItemsMap.contains(x._1))
144-
.flatMap { x =>
145-
val frequentItemSet = frequentItemsMap(x._1)
146-
val filteredSequence = x._2.filter(frequentItemSet.contains(_))
147-
val subProjectedDabase = frequentItemSet.map{ y =>
148-
(y, LocalPrefixSpan.getSuffix(y, filteredSequence))
149-
}.filter(_._2.nonEmpty)
150-
subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2))
151-
}
152-
(patternAndCounts, nextPrefixAndProjectedDatabase)
153-
}
89+
val groupedProjectedDatabase = prefixAndProjectedDatabase
90+
.map(x => (x._1.toSeq, x._2))
91+
.groupByKey()
92+
.map(x => (x._1.toArray, x._2.toArray))
93+
val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
94+
val lengthOnePatternsAndCountsRdd =
95+
sequences.sparkContext.parallelize(
96+
lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
97+
val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
98+
allPatterns
15499
}
155100

156101
/**

0 commit comments

Comments
 (0)