@@ -19,7 +19,9 @@ package org.apache.spark.ml.fpm
 
 import scala.reflect.ClassTag
 
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
@@ -175,7 +177,8 @@ class FPGrowth @Since("2.2.0") (
     if (handlePersistence) {
       items.persist(StorageLevel.MEMORY_AND_DISK)
     }
-
+    val inputRowCount = items.count()
+    instr.logNumExamples(inputRowCount)
     val parentModel = mllibFP.run(items)
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
     val schema = StructType(Seq(
@@ -187,7 +190,8 @@ class FPGrowth @Since("2.2.0") (
       items.unpersist()
     }
 
-    copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport)).setParent(this)
+    copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount))
+      .setParent(this)
   }
 
   @Since("2.2.0")
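
Note: `inputRowCount` is taken while `items` is still persisted, and it is what
makes the writer/reader simplification below possible: a single item's support
is just its frequency divided by the number of input transactions. A minimal
sketch of that arithmetic (plain Scala, hypothetical numbers, not part of the
patch):

    // support(item) = freq(singleton itemset) / total transaction count
    val inputRowCount = 5L                          // hypothetical total
    val singletonFreq = Map("a" -> 4L, "b" -> 3L)   // hypothetical frequencies
    val itemSupport = singletonFreq.map { case (item, freq) =>
      item -> freq.toDouble / inputRowCount         // "a" -> 0.8, "b" -> 0.6
    }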
@@ -218,7 +222,8 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
     @Since("2.2.0") @transient val freqItemsets: DataFrame,
-    private val itemSupport: scala.collection.Map[Any, Double])
+    private val itemSupport: scala.collection.Map[Any, Double],
+    private val inputSize: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
   /** @group setParam */
@@ -302,7 +307,7 @@ class FPGrowthModel private[ml] (
 
   @Since("2.2.0")
   override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets, itemSupport)
+    val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, inputSize)
     copyValues(copied, extra).setParent(this.parent)
   }
 
@@ -324,23 +329,10 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
324329 class FPGrowthModelWriter (instance : FPGrowthModel ) extends MLWriter {
325330
326331 override protected def saveImpl (path : String ): Unit = {
327- DefaultParamsWriter .saveMetadata(instance, path, sc)
332+ val extraMetadata : JObject = Map (" count" -> instance.inputSize)
333+ DefaultParamsWriter .saveMetadata(instance, path, sc, extraMetadata = Some (extraMetadata))
328334 val dataPath = new Path (path, " data" ).toString
329335 instance.freqItemsets.write.parquet(dataPath)
330- val itemDataType = instance.freqItemsets.schema(instance.getItemsCol).dataType match {
331- case ArrayType (et, _) => et
332- case other => throw new IllegalArgumentException (
333- s " Expected ${ArrayType .simpleString}, but got ${other.catalogString}. " )
334- }
335- val itemSupportPath = new Path (path, " itemSupport" ).toString
336- val itemSupportRows = instance.itemSupport.map {
337- case (item, support) => Row (item, support)
338- }.toSeq
339- val schema = StructType (Seq (
340- StructField (" item" , itemDataType, nullable = false ),
341- StructField (" support" , DoubleType , nullable = false )))
342- sparkSession.createDataFrame(sc.parallelize(itemSupportRows), schema)
343- .repartition(1 ).write.parquet(itemSupportPath)
344336 }
345337 }
346338
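
The writer now persists the transaction count as extra metadata instead of
writing a separate "itemSupport" parquet directory; `DefaultParamsWriter.saveMetadata`
merges the `extraMetadata` fields into the single metadata JSON file it already
produces. A standalone sketch of what the JsonDSL conversion yields, assuming
json4s with the jackson backend on the classpath (hypothetical count value):

    import org.json4s.JObject
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    // JsonDSL implicitly converts the Map[String, Long] into a JObject.
    val extraMetadata: JObject = Map("count" -> 5L)
    println(compact(render(extraMetadata)))         // prints {"count":5}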
@@ -350,19 +342,17 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
+      implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val inputCount = (metadata.metadata \ "count").extract[Long]
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
-      val itemSupportPath = new Path(path, "itemSupport")
-      val fs = FileSystem.get(sc.hadoopConfiguration)
-      val itemSupport = if (fs.exists(itemSupportPath)) {
-        sparkSession.read.parquet(itemSupportPath.toString).rdd.map {
-          case Row(item: Any, support: Double) => item -> support
+      val itemSupport = frequentItems.rdd.flatMap {
+        case Row(items: Seq[_], count: Long) if items.length == 1 =>
+          Some(items.head -> count.toDouble / inputCount)
+        case _ => None
       }.collectAsMap()
-      } else {
-        Map.empty[Any, Double]
-      }
-      val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport)
+      val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, inputCount)
       metadata.getAndSetParams(model)
       model
     }
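
With the count stored in metadata, the reader rebuilds `itemSupport` from the
singleton frequent itemsets themselves, so the `FileSystem` existence check and
the `Map.empty` fallback are no longer needed. A self-contained sketch of the
same reconstruction over a plain collection (hypothetical data, no Spark
required):

    // Each pair mirrors an (items, freq) row of the frequent-itemsets DataFrame.
    val inputCount = 5L
    val freqItemsets = Seq((Seq("a"), 4L), (Seq("b"), 3L), (Seq("a", "b"), 3L))
    val itemSupport = freqItemsets.flatMap {
      case (items, freq) if items.length == 1 =>
        Some(items.head -> freq.toDouble / inputCount)  // singleton => its support
      case _ => None                                    // larger itemsets skipped
    }.toMap                                             // Map(a -> 0.8, b -> 0.6)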