@@ -85,6 +85,36 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
8585 assert(prediction.select(" prediction" ).where(" id=3" ).first().getSeq[String ](0 ).isEmpty)
8686 }
8787
88+ test(" FPGrowth prediction should not contain duplicates" ) {
89+ // This should generate rule 1 -> 3, 2 -> 3
90+ val dataset = spark.createDataFrame(Seq (
91+ Array (" 1" , " 3" ),
92+ Array (" 2" , " 3" )
93+ ).map(Tuple1 (_))).toDF(" features" )
94+ val model = new FPGrowth ().fit(dataset)
95+
96+ val prediction = model.transform(
97+ spark.createDataFrame(Seq (Tuple1 (Array (" 1" , " 2" )))).toDF(" features" )
98+ ).first().getAs[Seq [String ]](" prediction" )
99+
100+ assert(prediction === Seq (" 3" ))
101+ }
102+
103+ test(" FPGrowthModel setMinConfidence should affect rules generation and transform" ) {
104+ val model = new FPGrowth ().setMinSupport(0.1 ).setMinConfidence(0.1 ).fit(dataset)
105+ val oldRulesNum = model.associationRules.count()
106+ assert(oldRulesNum == model.associationRules.count())
107+ val oldPredict = model.transform(dataset)
108+
109+ model.setMinConfidence(0.1 )
110+ assert(oldRulesNum === model.associationRules.count())
111+ assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
112+
113+ model.setMinConfidence(0.8765 )
114+ assert(oldRulesNum > model.associationRules.count())
115+ assert(! model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
116+ }
117+
88118 test(" FPGrowth parameter check" ) {
89119 val fpGrowth = new FPGrowth ().setMinSupport(0.4567 )
90120 val model = fpGrowth.fit(dataset)
@@ -95,28 +125,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
95125
96126 test(" read/write" ) {
97127 def checkModelData (model : FPGrowthModel , model2 : FPGrowthModel ): Unit = {
98- assert(model.freqItemsets.sort(" items" ).collect() ===
99- model2.freqItemsets.sort(" items" ).collect())
128+ assert(model.freqItemsets.collect().toSet.equals(
129+ model2.freqItemsets.collect().toSet))
130+ assert(model.associationRules.collect().toSet.equals(
131+ model2.associationRules.collect().toSet))
132+ assert(model.setMinConfidence(0.9 ).associationRules.collect().toSet.equals(
133+ model2.setMinConfidence(0.9 ).associationRules.collect().toSet))
100134 }
101135 val fPGrowth = new FPGrowth ()
102136 testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite .allParamSettings,
103137 FPGrowthSuite .allParamSettings, checkModelData)
104138 }
105-
106- test(" FPGrowth prediction should not contain duplicates" ) {
107- // This should generate rule 1 -> 3, 2 -> 3
108- val dataset = spark.createDataFrame(Seq (
109- Array (" 1" , " 3" ),
110- Array (" 2" , " 3" )
111- ).map(Tuple1 (_))).toDF(" features" )
112- val model = new FPGrowth ().fit(dataset)
113-
114- val prediction = model.transform(
115- spark.createDataFrame(Seq (Tuple1 (Array (" 1" , " 2" )))).toDF(" features" )
116- ).first().getAs[Seq [String ]](" prediction" )
117-
118- assert(prediction === Seq (" 3" ))
119- }
120139}
121140
122141object FPGrowthSuite {
0 commit comments