Commit 9580aa9

add support to association rules
1 parent: a0ad662 · commit: 9580aa9

File tree

2 files changed (+37, -19 lines)


mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 18 additions & 14 deletions
@@ -17,10 +17,11 @@
 
 package org.apache.spark.ml.fpm
 
-import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
@@ -190,7 +191,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
     @transient val freqItemsets: DataFrame,
-    val numTotalRecords: Long)
+    val numTrainingRecords: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
   /** @group setParam */
@@ -212,7 +213,8 @@ class FPGrowthModel private[ml] (
    */
   @Since("2.2.0")
   @transient lazy val associationRules: DataFrame = {
-    AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", numTotalRecords, $(minConfidence))
+    AssociationRules.getAssociationRulesFromFP(
+      freqItemsets, "items", "freq", numTrainingRecords, $(minConfidence))
   }
 
   /**
@@ -260,7 +262,7 @@ class FPGrowthModel private[ml] (
 
   @Since("2.2.0")
   override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, numTrainingRecords)
     copyValues(copied, extra).setParent(this.parent)
   }
 
@@ -282,7 +284,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata = "numTrainingRecords" -> instance.numTrainingRecords
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
       instance.freqItemsets.write.parquet(dataPath)
     }
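
For context: with `JsonDSL._` in scope, the `"numTrainingRecords" -> instance.numTrainingRecords` pair is implicitly lifted to a json4s `JObject`, which is what `DefaultParamsWriter.saveMetadata` merges into the metadata JSON. A minimal standalone sketch of that conversion (the literal `4L` is illustrative):

    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    // JsonDSL implicitly converts the (String, Long) pair into a JValue;
    // render/compact then serialize it as a one-field JSON object
    val extraMetadata = "numTrainingRecords" -> 4L
    println(compact(render(extraMetadata)))  // prints {"numTrainingRecords":4}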
@@ -295,9 +298,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
 
     override def load(path: String): FPGrowthModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      implicit val format = DefaultFormats
+      val numTrainingRecords = (metadata.metadata \ "numTrainingRecords").extract[Long]
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
-      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      val model = new FPGrowthModel(metadata.uid, frequentItems, numTrainingRecords)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
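
On the read path, `metadata.metadata` is the already-parsed JSON of the metadata file, so the record count round-trips losslessly through save and load. A minimal sketch of the same json4s extraction against a hand-built JSON string (the literal is illustrative):

    import org.json4s.DefaultFormats
    import org.json4s.jackson.JsonMethods.parse

    implicit val formats = DefaultFormats
    // \ selects the field; extract[Long] deserializes it
    val json = parse("""{"uid":"fpgrowth","numTrainingRecords":4}""")
    val numTrainingRecords = (json \ "numTrainingRecords").extract[Long]  // 4L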
@@ -312,25 +317,24 @@ private[fpm] object AssociationRules {
    * algorithms like [[FPGrowth]].
    * @param itemsCol column name for frequent itemsets
    * @param freqCol column name for frequent itemsets count
+   * @param numTrainingRecords count of training Dataset
    * @param minConfidence minimum confidence for the result association rules
    * @return a DataFrame("antecedent", "consequent", "confidence") containing the association
    *         rules.
    */
   def getAssociationRulesFromFP[T: ClassTag](
-        dataset: Dataset[_],
-        itemsCol: String,
-        freqCol: String,
-        numTotalRecords: Long,
-        minConfidence: Double): DataFrame = {
+      dataset: Dataset[_],
+      itemsCol: String,
+      freqCol: String,
+      numTrainingRecords: Long,
+      minConfidence: Double): DataFrame = {
 
     val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
       .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
     val rows = new MLlibAssociationRules()
       .setMinConfidence(minConfidence)
       .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTotalRecords))
-
-
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords))
 
     val dt = dataset.schema(itemsCol).dataType
     val schema = StructType(Seq(
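
The fourth value appended to each Row is the rule's support: `freqUnion`, the number of training transactions containing antecedent and consequent together, divided by the size of the training set. A hand computation for the rule {2} => {1}, using the same counts as the new unit test below (4 training records, freq({2}) = 4, freq({1, 2}) = 2):

    // illustrative arithmetic only; these vals mirror the unit test below
    val numTrainingRecords = 4L  // transactions used to fit the model
    val freqAntecedent = 4L      // freq({2})
    val freqUnion = 2L           // freq({1, 2})

    val confidence = freqUnion.toDouble / freqAntecedent   // 2.0 / 4 = 0.5
    val support = freqUnion.toDouble / numTrainingRecords  // 2.0 / 4 = 0.5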

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 19 additions & 5 deletions
@@ -38,13 +38,13 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val model = new FPGrowth().setMinSupport(0.5).fit(data)
     val generatedRules = model.setMinConfidence(0.5).associationRules
     val expectedRules = spark.createDataFrame(Seq(
-      (Array("2"), Array("1"), 1.0),
-      (Array("1"), Array("2"), 0.75)
-    )).toDF("antecedent", "consequent", "confidence")
+      (Array("2"), Array("1"), 1.0, 0.75),
+      (Array("1"), Array("2"), 0.75, 0.75)
+    )).toDF("antecedent", "consequent", "confidence", "support")
       .withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
       .withColumn("consequent", col("consequent").cast(ArrayType(dt)))
-    assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
-      generatedRules.sort("antecedent").rdd.collect()))
+    assert(expectedRules.collect().toSet.equals(
+      generatedRules.collect().toSet))
 
     val transformed = model.transform(data)
     val expectedTransformed = spark.createDataFrame(Seq(
@@ -74,6 +74,20 @@
     assert(checkDF.count() == 3 && checkDF.filter(col("freq") === col("expectedFreq")).count() == 3)
   }
 
+  test("FPGrowth associationRules") {
+    val freqItemsets = spark.createDataFrame(Seq(
+      (Array("2"), 4L),
+      (Array("1"), 2L),
+      (Array("1", "2"), 2L)
+    )).toDF("items", "freq")
+    val model = new FPGrowthModel("fpgrowth", freqItemsets, 4L).setMinConfidence(0.1)
+    val expectedRules = spark.createDataFrame(Seq(
+      (Array("2"), Array("1"), 0.5, 0.5),
+      (Array("1"), Array("2"), 1.0, 0.5)
+    )).toDF("antecedent", "consequent", "confidence", "support")
+    assert(expectedRules.collect().toSet.equals(model.associationRules.collect().toSet))
+  }
+
   test("FPGrowth getFreqItems with Null") {
     val df = spark.createDataFrame(Seq(
       (1, Array("1", "2", "3", "5")),
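
Taken together, a minimal end-to-end sketch of how the new column surfaces through the public API (assuming an active `SparkSession` named `spark`; the toy transactions are chosen to reproduce the counts in the test above):

    import org.apache.spark.ml.fpm.FPGrowth

    // four transactions: freq({2}) = 4, freq({1}) = 2, freq({1, 2}) = 2
    val data = spark.createDataFrame(Seq(
      (0, Array("1", "2")),
      (1, Array("1", "2")),
      (2, Array("2")),
      (3, Array("2"))
    )).toDF("id", "items")

    val model = new FPGrowth()
      .setItemsCol("items")
      .setMinSupport(0.5)
      .fit(data)

    // associationRules now carries four columns:
    // antecedent, consequent, confidence, support
    model.setMinConfidence(0.1).associationRules.show()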
