Skip to content

Commit b28bbff

Browse files
YY-OnCalljkbradley
authored andcommitted
[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-20003 I was doing some test and found the issue. ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val and `setMinConfidence` in FPGrowthModel has no impact once associationRules got computed . I try to cache the associationRules to avoid re-computation if `minConfidence` is not changed, but this makes FPGrowthModel somehow stateful. Let me know if there's any concern. ## How was this patch tested? new unit test and I strength the unit test for model save/load to ensure the cache mechanism. Author: Yuhao Yang <[email protected]> Closes #17336 from hhbyyh/fpmodelminconf.
1 parent a59759e commit b28bbff

File tree

2 files changed

+56
-21
lines changed

2 files changed

+56
-21
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
218218
def setPredictionCol(value: String): this.type = set(predictionCol, value)
219219

220220
/**
221-
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
221+
* Cache minConfidence and associationRules to avoid redundant computation for association rules
222+
* during transform. The associationRules will only be re-computed when minConfidence changed.
223+
*/
224+
@transient private var _cachedMinConf: Double = Double.NaN
225+
226+
@transient private var _cachedRules: DataFrame = _
227+
228+
/**
229+
* Get association rules fitted using the minConfidence. Returns a dataframe
222230
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
223231
* "consequent" are Array[T] and "confidence" is Double.
224232
*/
225233
@Since("2.2.0")
226-
@transient lazy val associationRules: DataFrame = {
227-
AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
234+
@transient def associationRules: DataFrame = {
235+
if ($(minConfidence) == _cachedMinConf) {
236+
_cachedRules
237+
} else {
238+
_cachedRules = AssociationRules
239+
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
240+
_cachedMinConf = $(minConfidence)
241+
_cachedRules
242+
}
228243
}
229244

230245
/**

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.apache.spark.ml.fpm
1818

1919
import org.apache.spark.SparkFunSuite
20-
import org.apache.spark.ml.util.DefaultReadWriteTest
20+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
2121
import org.apache.spark.mllib.util.MLlibTestSparkContext
2222
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
2323
import org.apache.spark.sql.functions._
@@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
8585
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
8686
}
8787

88+
test("FPGrowth prediction should not contain duplicates") {
89+
// This should generate rule 1 -> 3, 2 -> 3
90+
val dataset = spark.createDataFrame(Seq(
91+
Array("1", "3"),
92+
Array("2", "3")
93+
).map(Tuple1(_))).toDF("items")
94+
val model = new FPGrowth().fit(dataset)
95+
96+
val prediction = model.transform(
97+
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
98+
).first().getAs[Seq[String]]("prediction")
99+
100+
assert(prediction === Seq("3"))
101+
}
102+
103+
test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
104+
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
105+
val oldRulesNum = model.associationRules.count()
106+
val oldPredict = model.transform(dataset)
107+
108+
model.setMinConfidence(0.8765)
109+
assert(oldRulesNum > model.associationRules.count())
110+
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
111+
112+
// association rules should stay the same for same minConfidence
113+
model.setMinConfidence(0.1)
114+
assert(oldRulesNum === model.associationRules.count())
115+
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
116+
}
117+
88118
test("FPGrowth parameter check") {
89119
val fpGrowth = new FPGrowth().setMinSupport(0.4567)
90120
val model = fpGrowth.fit(dataset)
91121
.setMinConfidence(0.5678)
92122
assert(fpGrowth.getMinSupport === 0.4567)
93123
assert(model.getMinConfidence === 0.5678)
124+
MLTestingUtils.checkCopy(model)
94125
}
95126

96127
test("read/write") {
97128
def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
98-
assert(model.freqItemsets.sort("items").collect() ===
99-
model2.freqItemsets.sort("items").collect())
129+
assert(model.freqItemsets.collect().toSet.equals(
130+
model2.freqItemsets.collect().toSet))
131+
assert(model.associationRules.collect().toSet.equals(
132+
model2.associationRules.collect().toSet))
133+
assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
134+
model2.setMinConfidence(0.9).associationRules.collect().toSet))
100135
}
101136
val fPGrowth = new FPGrowth()
102137
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
103138
FPGrowthSuite.allParamSettings, checkModelData)
104139
}
105-
106-
test("FPGrowth prediction should not contain duplicates") {
107-
// This should generate rule 1 -> 3, 2 -> 3
108-
val dataset = spark.createDataFrame(Seq(
109-
Array("1", "3"),
110-
Array("2", "3")
111-
).map(Tuple1(_))).toDF("items")
112-
val model = new FPGrowth().fit(dataset)
113-
114-
val prediction = model.transform(
115-
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
116-
).first().getAs[Seq[String]]("prediction")
117-
118-
assert(prediction === Seq("3"))
119-
}
120140
}
121141

122142
object FPGrowthSuite {

0 commit comments

Comments
 (0)