From 70b93e32d2c1e4b7d09197a136404db661a0f95a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sun, 12 Jul 2015 16:48:27 -0700 Subject: [PATCH 1/5] Performance improvements in LocalPrefixSpan, fix tests --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 48 ++++++++++--------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 6 ++- .../spark/mllib/fpm/PrefixSpanSuite.scala | 14 ++---- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 39c48b084e550..892bfa61403e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -20,6 +20,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.Logging import org.apache.spark.annotation.Experimental +import scala.collection.mutable.ArrayBuffer + /** * * :: Experimental :: @@ -42,22 +44,20 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def run( minCount: Long, maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { + prefix: ArrayBuffer[Int], + projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = { val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) + .map(x => ((prefix :+ x._1).toArray, x._2)) val prefixProjectedDatabases = getPatternAndProjectedDatabase( prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => run(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns + if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) { + frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap { + case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB) + } } else { - frequentPatternAndCounts + frequentPatternAndCounts.iterator } } @@ -86,28 +86,30 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, sequences: Array[Array[Int]]): Array[(Int, Long)] = { sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) + .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => + ctr + (item -> (ctr(item) + 1)) + } .filter(_._2 >= minCount) .toArray } /** * Get the frequent prefixes' projected database. 
- * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database + * @param prefix the frequent prefixes' prefix + * @param frequentPrefixes frequent next prefixes + * @param projDB projected database for given prefix + * @return extensions of prefix by one item and corresponding projected databases */ private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], + prefix: ArrayBuffer[Int], frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) + projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = { + val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_))) + frequentPrefixes.map { nextItem => + val nextProjDB = filteredProjectedDatabase + .map(candidateSeq => getSuffix(nextItem, candidateSeq)) + .filter(_.nonEmpty) + (prefix :+ nextItem, nextProjDB) }.filter(x => x._2.nonEmpty) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9d8c60ef0fc45..0bccb37fa9cd5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -22,6 +22,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import scala.collection.mutable.ArrayBuffer + /** * * :: Experimental :: @@ -150,8 +152,8 @@ class PrefixSpan private ( private def getPatternsInLocal( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { x => - LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) + data.flatMap { case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 413436d3db85f..87b87569e2ec9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { test("PrefixSpan using Integer type") { @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - val sortedExpectedValue = expectedValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - val sortedActualValue = actualValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - sortedExpectedValue.zip(sortedActualValue) - .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) - .reduce(_&&_) + expectedValue.map(x => 
(x._1.toList, x._2)).toSet == + actualValue.map(x => (x._1.toList, x._2)).toSet } val prefixspan = new PrefixSpan() From 2e00cba1ef52eed77d6df1b2acfafb741ac427ef Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 13 Jul 2015 14:50:52 -0700 Subject: [PATCH 2/5] Depth first projections --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 79 ++++++++----------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- 2 files changed, 35 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 892bfa61403e0..53760645a53af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -20,8 +20,6 @@ package org.apache.spark.mllib.fpm import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import scala.collection.mutable.ArrayBuffer - /** * * :: Experimental :: @@ -36,7 +34,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { * @param minCount minimum count * @param maxPatternLength maximum pattern length * @param prefix prefix - * @param projectedDatabase the projected dabase + * @param database the projected dabase * @return a set of sequential pattern pairs, * the key of pair is sequential pattern (a list of items), * the value of pair is the pattern's count. @@ -44,31 +42,36 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def run( minCount: Long, maxPatternLength: Int, - prefix: ArrayBuffer[Int], - projectedDatabase: Array[Array[Int]]): Iterator[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => ((prefix :+ x._1).toArray, x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) + prefix: List[Int], + database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = { + + if (database.isEmpty) return Iterator.empty + + val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) + val frequentItems = frequentItemAndCounts.map(_._1) + val frequentPatternAndCounts = frequentItemAndCounts + .map { case (item, count) => ((item :: prefix).reverse.toArray, count) } - if (prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength) { - frequentPatternAndCounts.iterator ++ prefixProjectedDatabases.flatMap { - case (nextPrefix, projDB) => run(minCount, maxPatternLength, nextPrefix, projDB) + val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) + + if (prefix.length + 1 < maxPatternLength) { + frequentPatternAndCounts ++ frequentItems.flatMap { item => + val nextProjected = project(filteredProjectedDatabase, item) + run(minCount, maxPatternLength, item :: prefix, nextProjected) } } else { - frequentPatternAndCounts.iterator + frequentPatternAndCounts } } /** - * calculate suffix sequence following a prefix in a sequence - * @param prefix prefix - * @param sequence sequence + * Calculate suffix sequence immediately after the first occurrence of an item. 
+ * @param item item to get suffix after + * @param sequence sequence to extract suffix from * @return suffix sequence */ - def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(prefix) + def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(item) if (index == -1) { Array() } else { @@ -76,40 +79,26 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } + def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { + database + .map(candidateSeq => getSuffix(prefix, candidateSeq)) + .filter(_.nonEmpty) + } + /** * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences sequences data - * @return array of item and count pair + * @param minCount the minimum count for an item to be frequent + * @param database database of sequences + * @return item and count pairs */ private def getFreqItemAndCounts( minCount: Long, - sequences: Array[Array[Int]]): Array[(Int, Long)] = { - sequences.flatMap(_.distinct) + database: Iterable[Array[Int]]): Iterator[(Int, Long)] = { + database.flatMap(_.distinct) .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => ctr + (item -> (ctr(item) + 1)) } .filter(_._2 >= minCount) - .toArray - } - - /** - * Get the frequent prefixes' projected database. - * @param prefix the frequent prefixes' prefix - * @param frequentPrefixes frequent next prefixes - * @param projDB projected database for given prefix - * @return extensions of prefix by one item and corresponding projected databases - */ - private def getPatternAndProjectedDatabase( - prefix: ArrayBuffer[Int], - frequentPrefixes: Array[Int], - projDB: Array[Array[Int]]): Array[(ArrayBuffer[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = projDB.map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { nextItem => - val nextProjDB = filteredProjectedDatabase - .map(candidateSeq => getSuffix(nextItem, candidateSeq)) - .filter(_.nonEmpty) - (prefix :+ nextItem, nextProjDB) - }.filter(x => x._2.nonEmpty) + .iterator } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 0bccb37fa9cd5..73ba3bb63dfcb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -153,7 +153,7 @@ class PrefixSpan private ( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { data.flatMap { case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.to[ArrayBuffer], projDB) + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) } } } From f055d829a683e6a0336d26847e93ad2ee05c1c8a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 13 Jul 2015 16:55:37 -0700 Subject: [PATCH 3/5] Fix failing scalatest --- .../org/apache/spark/mllib/fpm/LocalPrefixSpan.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 53760645a53af..6a418dcc6fe82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -48,19 +48,19 @@ private[fpm] object LocalPrefixSpan 
extends Logging with Serializable { if (database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) - val frequentItems = frequentItemAndCounts.map(_._1) + val frequentItems = frequentItemAndCounts.map(_._1).toSet val frequentPatternAndCounts = frequentItemAndCounts .map { case (item, count) => ((item :: prefix).reverse.toArray, count) } val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) if (prefix.length + 1 < maxPatternLength) { - frequentPatternAndCounts ++ frequentItems.flatMap { item => + frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => val nextProjected = project(filteredProjectedDatabase, item) run(minCount, maxPatternLength, item :: prefix, nextProjected) } } else { - frequentPatternAndCounts + frequentPatternAndCounts.iterator } } @@ -93,12 +93,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Iterable[Array[Int]]): Iterator[(Int, Long)] = { + database: Iterable[Array[Int]]): Iterable[(Int, Long)] = { database.flatMap(_.distinct) .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => ctr + (item -> (ctr(item) + 1)) } .filter(_._2 >= minCount) - .iterator } } From 921225661d15b0b1d0fa10927906f5341eeb4e27 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 14 Jul 2015 17:05:56 -0700 Subject: [PATCH 4/5] MengXR code review comments --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 22 +++++++++---------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 3 +-- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 6a418dcc6fe82..307034f7cd607 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -17,16 +17,13 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable + import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental /** - * - * :: Experimental :: - * * Calculate all patterns of a projected database in local. 
*/ -@Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** @@ -43,18 +40,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefix: List[Int], - database: Iterable[Array[Int]]): Iterator[(Array[Int], Long)] = { + database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { if (database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val frequentItems = frequentItemAndCounts.map(_._1).toSet val frequentPatternAndCounts = frequentItemAndCounts - .map { case (item, count) => ((item :: prefix).reverse.toArray, count) } + .map { case (item, count) => ((item :: prefix), count) } - val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) if (prefix.length + 1 < maxPatternLength) { + val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => val nextProjected = project(filteredProjectedDatabase, item) run(minCount, maxPatternLength, item :: prefix, nextProjected) @@ -79,7 +76,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { + def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { database .map(candidateSeq => getSuffix(prefix, candidateSeq)) .filter(_.nonEmpty) @@ -93,10 +90,11 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Iterable[Array[Int]]): Iterable[(Int, Long)] = { + database: Array[Array[Int]]): Iterable[(Int, Long)] = { database.flatMap(_.distinct) - .foldRight(Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => - ctr + (item -> (ctr(item) + 1)) + .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => + ctr(item) += 1 + ctr } .filter(_._2 >= minCount) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 73ba3bb63dfcb..6f52db7b073ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -22,8 +22,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import scala.collection.mutable.ArrayBuffer - /** * * :: Experimental :: @@ -154,6 +152,7 @@ class PrefixSpan private ( data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { data.flatMap { case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } } } } From 91e43577045088c23e08855322c6ff5624421cee Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 14 Jul 2015 21:12:26 -0700 Subject: [PATCH 5/5] update LocalPrefixSpan impl --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 49 ++++++++----------- .../spark/mllib/fpm/PrefixSpanSuite.scala | 4 +- 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 307034f7cd607..7ead6327486cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -30,34 +30,25 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param database the projected dabase + * @param prefixes prefixes in reversed order + * @param database the projected database * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items), + * the key of pair is sequential pattern (a list of items in reversed order), * the value of pair is the pattern's count. */ def run( minCount: Long, maxPatternLength: Int, - prefix: List[Int], + prefixes: List[Int], database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { - - if (database.isEmpty) return Iterator.empty - + if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) - val frequentItems = frequentItemAndCounts.map(_._1).toSet - val frequentPatternAndCounts = frequentItemAndCounts - .map { case (item, count) => ((item :: prefix), count) } - - - if (prefix.length + 1 < maxPatternLength) { - val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_))) - frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item => - val nextProjected = project(filteredProjectedDatabase, item) - run(minCount, maxPatternLength, item :: prefix, nextProjected) - } - } else { - frequentPatternAndCounts.iterator + val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) + frequentItemAndCounts.iterator.flatMap { case (item, count) => + val newPrefixes = item :: prefixes + val newProjected = project(filteredDatabase, item) + Iterator.single((newPrefixes, count)) ++ + run(minCount, maxPatternLength, newPrefixes, newProjected) } } @@ -78,7 +69,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { database - .map(candidateSeq => getSuffix(prefix, candidateSeq)) + .map(getSuffix(prefix, _)) .filter(_.nonEmpty) } @@ -86,16 +77,18 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { * Generates frequent items by filtering the input data using minimal count level. 
* @param minCount the minimum count for an item to be frequent * @param database database of sequences - * @return item and count pairs + * @return freq item to count map */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): Iterable[(Int, Long)] = { - database.flatMap(_.distinct) - .foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) => - ctr(item) += 1 - ctr + database: Array[Array[Int]]): mutable.Map[Int, Long] = { + // TODO: use PrimitiveKeyOpenHashMap + val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + database.foreach { sequence => + sequence.distinct.foreach { item => + counts(item) += 1L } - .filter(_._2 >= minCount) + } + counts.filter(_._2 >= minCount) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 87b87569e2ec9..9f107c89f6d80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -47,8 +47,8 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toList, x._2)).toSet == - actualValue.map(x => (x._1.toList, x._2)).toSet + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } val prefixspan = new PrefixSpan()
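
Taken together, the five patches replace the original breadth-first search, which concatenated whole result arrays at every recursion level, with a depth-first recursion that streams patterns through Iterators and projects the database one frequent item at a time. Below is a minimal standalone sketch of the algorithm the series converges on (patch 5/5), assuming integer-encoded items and a small in-memory database; the object name is illustrative and this is not the MLlib source itself:

import scala.collection.mutable

// Sketch of the depth-first LocalPrefixSpan recursion. `prefixes` is kept in
// reversed order so that extending a pattern is a cheap cons; results stream
// out through an Iterator instead of being concatenated level by level.
object LocalPrefixSpanSketch {

  def run(
      minCount: Long,
      maxPatternLength: Int,
      prefixes: List[Int],
      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
    // Items that are no longer frequent can never extend a pattern, so they
    // are dropped from every sequence before projecting.
    val filteredDatabase = database.map(_.filter(frequentItemAndCounts.contains))
    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
      val newPrefixes = item :: prefixes
      Iterator.single((newPrefixes, count)) ++
        run(minCount, maxPatternLength, newPrefixes, project(filteredDatabase, item))
    }
  }

  // Suffix of `sequence` strictly after the first occurrence of `item`,
  // or an empty array when `item` does not occur.
  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(item)
    if (index == -1) Array() else sequence.drop(index + 1)
  }

  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] =
    database.map(getSuffix(prefix, _)).filter(_.nonEmpty)

  // Each distinct item is counted at most once per sequence, then filtered
  // against the minimum support count.
  private def getFreqItemAndCounts(
      minCount: Long,
      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
    database.foreach { sequence =>
      sequence.distinct.foreach { item => counts(item) += 1L }
    }
    counts.filter(_._2 >= minCount)
  }
}

The counting step also evolved across the series: patch 1 replaced groupBy with a fold over immutable Maps, and patch 5 settles on a single mutable map filled in one pass, avoiding the allocation of a fresh Map per element; the remaining TODO in patch 5 notes that a primitive-keyed open hash map could reduce boxing further.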
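
Because prefixes are threaded through the recursion as reversed Lists, nothing is copied until the driver reverses each completed pattern exactly once, as PrefixSpan.getPatternsInLocal does after patch 4/5. A hypothetical driver-side use of the sketch above; the database, support threshold, and length cap here are invented for illustration:

object LocalPrefixSpanSketchDemo {
  def main(args: Array[String]): Unit = {
    val db = Array(Array(1, 3, 2), Array(1, 2, 3), Array(3, 1, 2))
    LocalPrefixSpanSketch
      .run(minCount = 2L, maxPatternLength = 3, prefixes = Nil, database = db)
      // Patterns arrive in reversed order; reverse each exactly once at the
      // end, mirroring PrefixSpan.getPatternsInLocal after patch 4/5.
      .map { case (reversed, count) => (reversed.reverse, count) }
      .foreach { case (pattern, count) => println(pattern.mkString(",") + " -> " + count) }
  }
}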
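
The test-suite change relies on the fact that Scala's Array inherits Java's reference equality, so results keyed by Array[Int] cannot be compared as sets directly; patch 1 converts keys with toList and patch 5 with toSeq, either of which restores structural equality. A small demonstration of the distinction (the values are invented, not taken from the suite):

object ArrayEqualityDemo {
  def main(args: Array[String]): Unit = {
    // false: arrays compare by identity, so these sets share no elements
    println(Set(Array(1, 2)) == Set(Array(1, 2)))
    // true: Seq keys compare structurally and set comparison is order-insensitive
    println(Set(Array(1, 2).toSeq, Array(3).toSeq) ==
      Set(Array(3).toSeq, Array(1, 2).toSeq))
  }
}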