Skip to content

Commit 957a6a2

Browse files
committed
compute itemSupport instead of saving it
1 parent 5970876 commit 957a6a2

File tree

1 file changed

+19
-29
lines changed

1 file changed

+19
-29
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 19 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,9 @@ package org.apache.spark.ml.fpm
1919

2020
import scala.reflect.ClassTag
2121

22-
import org.apache.hadoop.fs.{FileSystem, Path}
22+
import org.apache.hadoop.fs.Path
23+
import org.json4s.{DefaultFormats, JObject}
24+
import org.json4s.JsonDSL._
2325

2426
import org.apache.spark.annotation.{Experimental, Since}
2527
import org.apache.spark.ml.{Estimator, Model}
@@ -175,7 +177,8 @@ class FPGrowth @Since("2.2.0") (
175177
if (handlePersistence) {
176178
items.persist(StorageLevel.MEMORY_AND_DISK)
177179
}
178-
180+
val inputRowCount = items.count()
181+
instr.logNumExamples(inputRowCount)
179182
val parentModel = mllibFP.run(items)
180183
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
181184
val schema = StructType(Seq(
@@ -187,7 +190,8 @@ class FPGrowth @Since("2.2.0") (
187190
items.unpersist()
188191
}
189192

190-
copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport)).setParent(this)
193+
copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount))
194+
.setParent(this)
191195
}
192196

193197
@Since("2.2.0")
@@ -218,7 +222,8 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
218222
class FPGrowthModel private[ml] (
219223
@Since("2.2.0") override val uid: String,
220224
@Since("2.2.0") @transient val freqItemsets: DataFrame,
221-
private val itemSupport: scala.collection.Map[Any, Double])
225+
private val itemSupport: scala.collection.Map[Any, Double],
226+
private val inputSize: Long)
222227
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
223228

224229
/** @group setParam */
@@ -302,7 +307,7 @@ class FPGrowthModel private[ml] (
302307

303308
@Since("2.2.0")
304309
override def copy(extra: ParamMap): FPGrowthModel = {
305-
val copied = new FPGrowthModel(uid, freqItemsets, itemSupport)
310+
val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, inputSize)
306311
copyValues(copied, extra).setParent(this.parent)
307312
}
308313

@@ -324,23 +329,10 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
324329
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
325330

326331
override protected def saveImpl(path: String): Unit = {
327-
DefaultParamsWriter.saveMetadata(instance, path, sc)
332+
val extraMetadata: JObject = Map("count" -> instance.inputSize)
333+
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata))
328334
val dataPath = new Path(path, "data").toString
329335
instance.freqItemsets.write.parquet(dataPath)
330-
val itemDataType = instance.freqItemsets.schema(instance.getItemsCol).dataType match {
331-
case ArrayType(et, _) => et
332-
case other => throw new IllegalArgumentException(
333-
s"Expected ${ArrayType.simpleString}, but got ${other.catalogString}.")
334-
}
335-
val itemSupportPath = new Path(path, "itemSupport").toString
336-
val itemSupportRows = instance.itemSupport.map {
337-
case (item, support) => Row(item, support)
338-
}.toSeq
339-
val schema = StructType(Seq(
340-
StructField("item", itemDataType, nullable = false),
341-
StructField("support", DoubleType, nullable = false)))
342-
sparkSession.createDataFrame(sc.parallelize(itemSupportRows), schema)
343-
.repartition(1).write.parquet(itemSupportPath)
344336
}
345337
}
346338

@@ -350,19 +342,17 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
350342
private val className = classOf[FPGrowthModel].getName
351343

352344
override def load(path: String): FPGrowthModel = {
345+
implicit val format = DefaultFormats
353346
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
347+
val inputCount = (metadata.metadata \ "count").extract[Long]
354348
val dataPath = new Path(path, "data").toString
355349
val frequentItems = sparkSession.read.parquet(dataPath)
356-
val itemSupportPath = new Path(path, "itemSupport")
357-
val fs = FileSystem.get(sc.hadoopConfiguration)
358-
val itemSupport = if (fs.exists(itemSupportPath)) {
359-
sparkSession.read.parquet(itemSupportPath.toString).rdd.map {
360-
case Row(item: Any, support: Double) => item -> support
350+
val itemSupport = frequentItems.rdd.flatMap {
351+
case Row(items: Seq[_], count: Long) if items.length == 1 =>
352+
Some(items.head -> count.toDouble / inputCount)
353+
case _ => None
361354
}.collectAsMap()
362-
} else {
363-
Map.empty[Any, Double]
364-
}
365-
val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport)
355+
val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, inputCount)
366356
metadata.getAndSetParams(model)
367357
model
368358
}

0 commit comments

Comments
 (0)