-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-8998][MLlib] Collect enough frequent prefixes before local processing in PrefixSpan (new) #7412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-8998][MLlib] Collect enough frequent prefixes before local processing in PrefixSpan (new) #7412
Changes from all commits
91fd7e6
575995f
951fd42
a2eb14c
89bc368
1dd33ad
4c60fb3
ba5df34
574e56c
ca9c4c8
22b0ef4
078d410
4dd1c8a
a8fde87
6560c69
baa2885
095aa3a
b07e20c
d2250b7
64271b3
6e149fa
01c9ae9
cb2a4fc
da0091b
1235cfc
c2caa5c
87fa021
ad23aa9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.mllib.fpm | ||
|
|
||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.rdd.RDD | ||
|
|
@@ -43,28 +45,44 @@ class PrefixSpan private ( | |
| private var minSupport: Double, | ||
| private var maxPatternLength: Int) extends Logging with Serializable { | ||
|
|
||
| /** | ||
| * The maximum number of items allowed in a projected database before local processing. If a | ||
| * projected database exceeds this size, another iteration of distributed PrefixSpan is run. | ||
| */ | ||
| private val maxLocalProjDBSize: Long = 10000 | ||
|
|
||
| /** | ||
| * Constructs a default instance with default parameters | ||
| * {minSupport: `0.1`, maxPatternLength: `10`}. | ||
| */ | ||
| def this() = this(0.1, 10) | ||
|
|
||
| /** | ||
| * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered | ||
| * frequent). | ||
| */ | ||
| def getMinSupport(): Double = this.minSupport | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
|
|
||
| /** | ||
| * Sets the minimal support level (default: `0.1`). | ||
| */ | ||
| def setMinSupport(minSupport: Double): this.type = { | ||
| require(minSupport >= 0 && minSupport <= 1, | ||
| "The minimum support value must be between 0 and 1, including 0 and 1.") | ||
| require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") | ||
| this.minSupport = minSupport | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider). | ||
| */ | ||
| def getMaxPatternLength(): Double = this.maxPatternLength | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
|
|
||
| /** | ||
| * Sets maximal pattern length (default: `10`). | ||
| */ | ||
| def setMaxPatternLength(maxPatternLength: Int): this.type = { | ||
| require(maxPatternLength >= 1, | ||
| "The maximum pattern length value must be greater than 0.") | ||
| // TODO: support unbounded pattern length when maxPatternLength = 0 | ||
| require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") | ||
| this.maxPatternLength = maxPatternLength | ||
| this | ||
| } | ||
|
|
@@ -81,78 +99,145 @@ class PrefixSpan private ( | |
| if (sequences.getStorageLevel == StorageLevel.NONE) { | ||
| logWarning("Input data is not cached.") | ||
| } | ||
| val minCount = getMinCount(sequences) | ||
| val lengthOnePatternsAndCounts = | ||
| getFreqItemAndCounts(minCount, sequences).collect() | ||
| val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( | ||
| lengthOnePatternsAndCounts.map(_._1), sequences) | ||
| val groupedProjectedDatabase = prefixAndProjectedDatabase | ||
| .map(x => (x._1.toSeq, x._2)) | ||
| .groupByKey() | ||
| .map(x => (x._1.toArray, x._2.toArray)) | ||
| val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) | ||
| val lengthOnePatternsAndCountsRdd = | ||
| sequences.sparkContext.parallelize( | ||
| lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) | ||
| val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns | ||
| allPatterns | ||
|
|
||
| // Convert min support to a min number of transactions for this dataset | ||
| val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong | ||
|
|
||
| // (Frequent item -> number of occurrences); all items here satisfy the `minSupport` threshold | ||
| val freqItemCounts = sequences | ||
| .flatMap(seq => seq.distinct.map(item => (item, 1L))) | ||
| .reduceByKey(_ + _) | ||
| .filter(_._2 >= minCount) | ||
|
|
||
| // Pairs of (length 1 prefix, suffix consisting of frequent items) | ||
| val itemSuffixPairs = { | ||
| val freqItems = freqItemCounts.keys.collect().toSet | ||
| sequences.flatMap { seq => | ||
| val filteredSeq = seq.filter(freqItems.contains(_)) | ||
| freqItems.flatMap { item => | ||
| val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) | ||
| candidateSuffix match { | ||
| case suffix if !suffix.isEmpty => Some((List(item), suffix)) | ||
| case _ => None | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. | ||
| // frequent length-one prefixes) | ||
| var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can collect |
||
|
|
||
| // Remaining work to be locally and distributively processed respectively | ||
| var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) | ||
|
|
||
| // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have | ||
| // projected database sizes <= `maxLocalProjDBSize`) | ||
| while (pairsForDistributed.count() != 0) { | ||
| val (nextPatternAndCounts, nextPrefixSuffixPairs) = | ||
| extendPrefixes(minCount, pairsForDistributed) | ||
| pairsForDistributed.unpersist() | ||
| val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) | ||
| pairsForDistributed = largerPairsPart | ||
| pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) | ||
| pairsForLocal ++= smallerPairsPart | ||
| resultsAccumulator ++= nextPatternAndCounts | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might create many small partitions in the output. The small patterns could be collected to local and make a single partition. |
||
| } | ||
|
|
||
| // Process the small projected databases locally | ||
| resultsAccumulator ++= getPatternsInLocal(minCount, pairsForLocal.groupByKey()) | ||
|
|
||
| resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) } | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Get the minimum count (sequences count * minSupport). | ||
| * @param sequences input data set, contains a set of sequences, | ||
| * @return minimum count, | ||
| * Partitions the prefix-suffix pairs by projected database size. | ||
| * @param prefixSuffixPairs prefix (length n) and suffix pairs, | ||
| * @return prefix-suffix pairs partitioned by whether their projected database size is <= or | ||
| * greater than [[maxLocalProjDBSize]] | ||
| */ | ||
| private def getMinCount(sequences: RDD[Array[Int]]): Long = { | ||
| if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong | ||
| private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) | ||
| : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { | ||
| val prefixToSuffixSize = prefixSuffixPairs | ||
| .aggregateByKey(0)( | ||
| seqOp = { case (count, suffix) => count + suffix.length }, | ||
| combOp = { _ + _ }) | ||
| val smallPrefixes = prefixToSuffixSize | ||
| .filter(_._2 <= maxLocalProjDBSize) | ||
| .keys | ||
| .collect() | ||
| .toSet | ||
| val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should collect |
||
| val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } | ||
| (small, large) | ||
| } | ||
|
|
||
| /** | ||
| * Generates frequent items by filtering the input data using minimal count level. | ||
| * @param minCount the absolute minimum count | ||
| * @param sequences original sequences data | ||
| * @return array of item and count pair | ||
| * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes | ||
| * and remaining work. | ||
| * @param minCount minimum count | ||
| * @param prefixSuffixPairs prefix (length N) and suffix pairs, | ||
| * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended | ||
| * prefix, corresponding suffix) pairs. | ||
| */ | ||
| private def getFreqItemAndCounts( | ||
| private def extendPrefixes( | ||
| minCount: Long, | ||
| sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { | ||
| sequences.flatMap(_.distinct.map((_, 1L))) | ||
| prefixSuffixPairs: RDD[(List[Int], Array[Int])]) | ||
| : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { | ||
|
|
||
| // (length N prefix, item from suffix) pairs and their corresponding number of occurrences | ||
| // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` | ||
| val prefixItemPairAndCounts = prefixSuffixPairs | ||
| .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } | ||
| .reduceByKey(_ + _) | ||
| .filter(_._2 >= minCount) | ||
| } | ||
|
|
||
| /** | ||
| * Get the frequent prefixes' projected database. | ||
| * @param frequentPrefixes frequent prefixes | ||
| * @param sequences sequences data | ||
| * @return prefixes and projected database | ||
| */ | ||
| private def getPrefixAndProjectedDatabase( | ||
| frequentPrefixes: Array[Int], | ||
| sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { | ||
| val filteredSequences = sequences.map { p => | ||
| p.filter (frequentPrefixes.contains(_) ) | ||
| } | ||
| filteredSequences.flatMap { x => | ||
| frequentPrefixes.map { y => | ||
| val sub = LocalPrefixSpan.getSuffix(y, x) | ||
| (Array(y), sub) | ||
| }.filter(_._2.nonEmpty) | ||
| } | ||
| // Map from prefix to set of possible next items from suffix | ||
| val prefixToNextItems = prefixItemPairAndCounts | ||
| .keys | ||
| .groupByKey() | ||
| .mapValues(_.toSet) | ||
| .collect() | ||
| .toMap | ||
|
|
||
|
|
||
| // Frequent patterns with length N+1 and their corresponding counts | ||
| val extendedPrefixAndCounts = prefixItemPairAndCounts | ||
| .map { case ((prefix, item), count) => (item :: prefix, count) } | ||
|
|
||
| // Remaining work, all prefixes will have length N+1 | ||
| val extendedPrefixAndSuffix = prefixSuffixPairs | ||
| .filter(x => prefixToNextItems.contains(x._1)) | ||
| .flatMap { case (prefix, suffix) => | ||
| val frequentNextItems = prefixToNextItems(prefix) | ||
| val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) | ||
| frequentNextItems.flatMap { item => | ||
| LocalPrefixSpan.getSuffix(item, filteredSuffix) match { | ||
| case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) | ||
| case _ => None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| (extendedPrefixAndCounts, extendedPrefixAndSuffix) | ||
| } | ||
|
|
||
| /** | ||
| * calculate the patterns in local. | ||
| * Calculate the patterns locally. | ||
| * @param minCount the absolute minimum count | ||
| * @param data patterns and projected sequences data | ||
| * @param data prefixes and projected sequences data | ||
| * @return patterns | ||
| */ | ||
| private def getPatternsInLocal( | ||
| minCount: Long, | ||
| data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { | ||
| data.flatMap { case (prefix, projDB) => | ||
| LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) | ||
| .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } | ||
| data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = { | ||
| data.flatMap { | ||
| case (prefix, projDB) => | ||
| LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) | ||
| .map { case (pattern: List[Int], count: Long) => | ||
| (pattern.reverse, count) | ||
| } | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please leave a TODO to make it configurable with a better default value.
10000 may be too small. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK