
package org.apache.spark.ml.fpm

-import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
@@ -190,7 +191,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
class FPGrowthModel private[ml] (
    @Since("2.2.0") override val uid: String,
    @transient val freqItemsets: DataFrame,
-    val numTotalRecords: Long)
+    val numTrainingRecords: Long)
  extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

  /** @group setParam */
@@ -212,7 +213,8 @@ class FPGrowthModel private[ml] (
   */
  @Since("2.2.0")
  @transient lazy val associationRules: DataFrame = {
-    AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", numTotalRecords, $(minConfidence))
+    AssociationRules.getAssociationRulesFromFP(
+      freqItemsets, "items", "freq", numTrainingRecords, $(minConfidence))
  }

  /**
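
For context on what the hunk above feeds, here is a minimal usage sketch of FPGrowth and the associationRules DataFrame; the local SparkSession, toy transactions, and thresholds are illustrative assumptions, not part of this patch.

import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("fpgrowth-sketch").master("local[*]").getOrCreate()
import spark.implicits._

// Hypothetical transactions; "items" is the column FPGrowth mines by default.
val dataset = spark.createDataset(Seq("1 2 5", "1 2 3 5", "1 2"))
  .map(_.split(" "))
  .toDF("items")

val model = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.5)
  .setMinConfidence(0.6)
  .fit(dataset)

model.freqItemsets.show()       // frequent itemsets mined from the 3 training records
model.associationRules.show()   // rules derived from them (3 training records in this toy example)
spark.stop()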
@@ -260,7 +262,7 @@ class FPGrowthModel private[ml] (

  @Since("2.2.0")
  override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, numTrainingRecords)
    copyValues(copied, extra).setParent(this.parent)
  }

@@ -282,7 +284,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
  class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata = "numTrainingRecords" -> instance.numTrainingRecords
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
      val dataPath = new Path(path, "data").toString
      instance.freqItemsets.write.parquet(dataPath)
    }
@@ -295,9 +298,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {

    override def load(path: String): FPGrowthModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      implicit val format = DefaultFormats
+      val numTrainingRecords = (metadata.metadata \ "numTrainingRecords").extract[Long]
      val dataPath = new Path(path, "data").toString
      val frequentItems = sparkSession.read.parquet(dataPath)
-      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      val model = new FPGrowthModel(metadata.uid, frequentItems, numTrainingRecords)
      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
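
As a small, self-contained illustration of the json4s pattern used by the writer and reader hunks above (the jackson backend and the value 42 are assumptions for this sketch, not taken from the patch):

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

implicit val formats: Formats = DefaultFormats

// Build the extra-metadata object with the JsonDSL pair syntax, serialize it,
// then read the field back the same way the model reader does.
val extraMetadata: JObject = "numTrainingRecords" -> 42L
val json = compact(render(extraMetadata))                         // {"numTrainingRecords":42}
val restored = (parse(json) \ "numTrainingRecords").extract[Long]
assert(restored == 42L)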
@@ -312,25 +317,24 @@ private[fpm] object AssociationRules {
   * algorithms like [[FPGrowth]].
   * @param itemsCol column name for frequent itemsets
   * @param freqCol column name for frequent itemsets count
+   * @param numTrainingRecords number of records in the training Dataset
   * @param minConfidence minimum confidence for the result association rules
   * @return a DataFrame("antecedent", "consequent", "confidence") containing the association
   *         rules.
   */
  def getAssociationRulesFromFP[T: ClassTag](
-    dataset: Dataset[_],
-    itemsCol: String,
-    freqCol: String,
-    numTotalRecords: Long,
-    minConfidence: Double): DataFrame = {
+      dataset: Dataset[_],
+      itemsCol: String,
+      freqCol: String,
+      numTrainingRecords: Long,
+      minConfidence: Double): DataFrame = {

    val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
      .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
    val rows = new MLlibAssociationRules()
      .setMinConfidence(minConfidence)
      .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTotalRecords))
-
-
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords))

    val dt = dataset.schema(itemsCol).dataType
    val schema = StructType(Seq(
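
Finally, a tiny worked example of the ratio added in the last hunk, with made-up counts; my reading is that freqUnion is the count of antecedent ∪ consequent, so the new column is the rule's support over the training data.

// Hypothetical counts, for illustration only.
val freqUnion = 30.0            // times the full itemset {antecedent ∪ consequent} occurs
val freqAntecedent = 40.0       // times the antecedent alone occurs
val numTrainingRecords = 100L   // size of the training Dataset

val confidence = freqUnion / freqAntecedent     // 0.75
val support = freqUnion / numTrainingRecords    // 0.3, the value this patch appends to each Row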