Skip to content

Commit 3398d62

Browse files
committed
set FPGrowthModel minConfidence
1 parent 376d782 commit 3398d62

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,25 @@ class FPGrowthModel private[ml] (
204204
@Since("2.2.0")
205205
def setPredictionCol(value: String): this.type = set(predictionCol, value)
206206

207+
@transient private var _cachedMinConf: Double = Double.NaN
208+
209+
@transient private var _cachedRules: DataFrame = null
210+
207211
/**
208212
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
209213
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
210214
* "consequent" are Array[T] and "confidence" is Double.
211215
*/
212216
@Since("2.2.0")
213-
@transient lazy val associationRules: DataFrame = {
214-
AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
217+
@transient def associationRules: DataFrame = {
218+
if ($(minConfidence) == _cachedMinConf) {
219+
_cachedRules
220+
} else {
221+
_cachedRules = AssociationRules
222+
.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
223+
_cachedMinConf = $(minConfidence)
224+
_cachedRules
225+
}
215226
}
216227

217228
/**

mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,36 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
8585
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
8686
}
8787

88+
test("FPGrowth prediction should not contain duplicates") {
89+
// This should generate rule 1 -> 3, 2 -> 3
90+
val dataset = spark.createDataFrame(Seq(
91+
Array("1", "3"),
92+
Array("2", "3")
93+
).map(Tuple1(_))).toDF("features")
94+
val model = new FPGrowth().fit(dataset)
95+
96+
val prediction = model.transform(
97+
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
98+
).first().getAs[Seq[String]]("prediction")
99+
100+
assert(prediction === Seq("3"))
101+
}
102+
103+
test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
104+
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
105+
val oldRulesNum = model.associationRules.count()
106+
assert(oldRulesNum == model.associationRules.count())
107+
val oldPredict = model.transform(dataset)
108+
109+
model.setMinConfidence(0.1)
110+
assert(oldRulesNum === model.associationRules.count())
111+
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
112+
113+
model.setMinConfidence(0.8765)
114+
assert(oldRulesNum > model.associationRules.count())
115+
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
116+
}
117+
88118
test("FPGrowth parameter check") {
89119
val fpGrowth = new FPGrowth().setMinSupport(0.4567)
90120
val model = fpGrowth.fit(dataset)
@@ -95,28 +125,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
95125

96126
test("read/write") {
97127
def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
98-
assert(model.freqItemsets.sort("items").collect() ===
99-
model2.freqItemsets.sort("items").collect())
128+
assert(model.freqItemsets.collect().toSet.equals(
129+
model2.freqItemsets.collect().toSet))
130+
assert(model.associationRules.collect().toSet.equals(
131+
model2.associationRules.collect().toSet))
132+
assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
133+
model2.setMinConfidence(0.9).associationRules.collect().toSet))
100134
}
101135
val fPGrowth = new FPGrowth()
102136
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
103137
FPGrowthSuite.allParamSettings, checkModelData)
104138
}
105-
106-
test("FPGrowth prediction should not contain duplicates") {
107-
// This should generate rule 1 -> 3, 2 -> 3
108-
val dataset = spark.createDataFrame(Seq(
109-
Array("1", "3"),
110-
Array("2", "3")
111-
).map(Tuple1(_))).toDF("features")
112-
val model = new FPGrowth().fit(dataset)
113-
114-
val prediction = model.transform(
115-
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
116-
).first().getAs[Seq[String]]("prediction")
117-
118-
assert(prediction === Seq("3"))
119-
}
120139
}
121140

122141
object FPGrowthSuite {

0 commit comments

Comments (0)