3 changes: 2 additions & 1 deletion R/pkg/tests/fulltests/test_mllib_fpm.R
@@ -44,7 +44,8 @@ test_that("spark.fpGrowth", {
expected_association_rules <- data.frame(
antecedent = I(list(list("2"), list("3"))),
consequent = I(list(list("1"), list("1"))),
-    confidence = c(1, 1)
+    confidence = c(1, 1),
+    support = c(0.75, 0.5)
)

expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))
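A note on the two new expected values: the support of a rule X => Y is freq(X ∪ Y) divided by the number of training records. The 0.75 and 0.5 here are consistent with a four-record training set in which {1, 2} co-occurs three times (3/4 = 0.75) and {1, 3} twice (2/4 = 0.5); the fixture itself sits above this hunk and is not shown.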
58 changes: 38 additions & 20 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ml.fpm
import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
@@ -34,6 +36,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils

/**
* Common params for FPGrowth and FPGrowthModel
@@ -187,7 +190,7 @@ class FPGrowth @Since("2.2.0") (
items.unpersist()
}

-    copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+    copyValues(new FPGrowthModel(uid, frequentItems, data.count())).setParent(this)
}

@Since("2.2.0")
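Note that recording the training-set size means fit() now runs one extra count() action over the input Dataset, an additional pass over the data unless the input is already cached.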
@@ -217,7 +220,8 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
@Experimental
class FPGrowthModel private[ml] (
@Since("2.2.0") override val uid: String,
@Since("2.2.0") @transient val freqItemsets: DataFrame)
@Since("2.2.0") @transient val freqItemsets: DataFrame,
@Since("2.5.0") val numTrainingRecords: Long)
extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

/** @group setParam */
@@ -241,17 +245,17 @@ class FPGrowthModel private[ml] (
@transient private var _cachedRules: DataFrame = _

/**
-   * Get association rules fitted using the minConfidence. Returns a dataframe
-   * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
-   * "consequent" are Array[T] and "confidence" is Double.
+   * Get the association rules fitted by AssociationRules using the minConfidence. Returns a
+   * DataFrame with four fields, "antecedent", "consequent", "confidence" and "support", where
+   * "antecedent" and "consequent" are Array[T], and "confidence" and "support" are Double.
*/
@Since("2.2.0")
@transient def associationRules: DataFrame = {
if ($(minConfidence) == _cachedMinConf) {
_cachedRules
} else {
-      _cachedRules = AssociationRules
-        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+      _cachedRules = AssociationRules.getAssociationRulesFromFP(
+        freqItemsets, "items", "freq", numTrainingRecords, $(minConfidence))
_cachedMinConf = $(minConfidence)
_cachedRules
}
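The branch above is a small memoization: rules are recomputed only when minConfidence changes, and repeated reads at the same threshold return the cached DataFrame. A minimal standalone sketch of the same idiom, where computeRules is a hypothetical stand-in for the AssociationRules call:

    // Recompute only when the threshold changes; otherwise reuse the cache.
    class RuleCache(computeRules: Double => Seq[String]) {
      // NaN never equals anything, so the first call always computes.
      private var cachedConf: Double = Double.NaN
      private var cached: Seq[String] = Seq.empty

      def rulesAt(minConfidence: Double): Seq[String] =
        if (minConfidence == cachedConf) {
          cached
        } else {
          cached = computeRules(minConfidence)
          cachedConf = minConfidence
          cached
        }
    }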
@@ -301,7 +305,7 @@ class FPGrowthModel private[ml] (

@Since("2.2.0")
override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, numTrainingRecords)
copyValues(copied, extra).setParent(this.parent)
}

@@ -323,7 +327,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata = "numTrainingRecords" -> instance.numTrainingRecords
+      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val dataPath = new Path(path, "data").toString
instance.freqItemsets.write.parquet(dataPath)
}
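With JsonDSL in scope, the String -> Long pair converts implicitly to a single-field JObject that DefaultParamsWriter merges into the model's metadata JSON. A sketch of the conversion in isolation (the value 4L is illustrative):

    import org.json4s.JObject
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    // The (String, Long) pair becomes a one-field JSON object via JsonDSL.
    val extra: JObject = "numTrainingRecords" -> 4L
    println(compact(render(extra)))   // prints {"numTrainingRecords":4}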
@@ -335,10 +340,20 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
+    implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+    val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+    val numTrainingRecords = if (major < 2 || (major == 2 && minor < 4)) {
+      // 2.3 and before: models were saved without the training count
+      1L
+    } else {
+      // 2.4+: read the count back from the metadata JSON
+      (metadata.metadata \ "numTrainingRecords").extract[Long]
+    }

val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
-    val model = new FPGrowthModel(metadata.uid, frequentItems)
+    val model = new FPGrowthModel(metadata.uid, frequentItems, numTrainingRecords)
metadata.getAndSetParams(model)
model
}
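For models written before the field existed, the reader falls back to 1L, which makes the support column degrade to the raw freqUnion count rather than fail on load. A sketch of the extraction step on its own, probing the JSON directly instead of gating on sparkVersion as the reader above does (the uid is illustrative):

    import org.json4s.DefaultFormats
    import org.json4s.jackson.JsonMethods.parse

    implicit val formats: DefaultFormats.type = DefaultFormats

    // Legacy metadata has no "numTrainingRecords" key; default to 1L then.
    val metadataJson = parse("""{"uid": "fpgrowth_4a3b", "numTrainingRecords": 4}""")
    val n = (metadataJson \ "numTrainingRecords").extractOrElse[Long](1L)   // 4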
@@ -352,29 +367,32 @@ private[fpm] object AssociationRules {
* @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained
* from algorithms like [[FPGrowth]].
* @param itemsCol column name for frequent itemsets
-   * @param freqCol column name for appearance count of the frequent itemsets
-   * @param minConfidence minimum confidence for generating the association rules
-   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
-   *         containing the association rules.
+   * @param freqCol column name for the appearance count of the frequent itemsets
+   * @param numTrainingRecords number of records in the training dataset
+   * @param minConfidence minimum confidence for the resulting association rules
+   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double],
+   *         "support"[Double]) containing the association rules.
*/
def getAssociationRulesFromFP[T: ClassTag](
-      dataset: Dataset[_],
-      itemsCol: String,
-      freqCol: String,
-      minConfidence: Double): DataFrame = {
+      dataset: Dataset[_],
+      itemsCol: String,
+      freqCol: String,
+      numTrainingRecords: Long,
+      minConfidence: Double): DataFrame = {

val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
.map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
val rows = new MLlibAssociationRules()
.setMinConfidence(minConfidence)
.run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence))
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.freqUnion / numTrainingRecords))

val dt = dataset.schema(itemsCol).dataType
val schema = StructType(Seq(
StructField("antecedent", dt, nullable = false),
StructField("consequent", dt, nullable = false),
StructField("confidence", DoubleType, nullable = false)))
StructField("confidence", DoubleType, nullable = false),
StructField("support", DoubleType, nullable = false)))
val rules = dataset.sparkSession.createDataFrame(rows, schema)
rules
}
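The substantive change in getAssociationRulesFromFP is one extra division per rule: MLlib's Rule already carries freq(antecedent ∪ consequent) as freqUnion, and support is that count normalized by the training-set size. The same computation over plain values, with illustrative types rather than the Spark API:

    // confidence = freq(X ∪ Y) / freq(X);  support = freq(X ∪ Y) / N
    final case class MinedRule(antecedent: Seq[String], consequent: Seq[String],
                               freqUnion: Double, freqAntecedent: Double)

    def metrics(r: MinedRule, numTrainingRecords: Long): (Double, Double) =
      (r.freqUnion / r.freqAntecedent, r.freqUnion / numTrainingRecords)

    // e.g. {2} => {1} with freq({1,2}) = 3 and freq({2}) = 3 in a 4-record set:
    // metrics(MinedRule(Seq("2"), Seq("1"), 3.0, 3.0), 4) == (1.0, 0.75)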
mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -106,7 +106,7 @@ object AssociationRules {
class Rule[Item] private[fpm] (
@Since("1.5.0") val antecedent: Array[Item],
@Since("1.5.0") val consequent: Array[Item],
-    freqUnion: Double,
+    private[spark] val freqUnion: Double,
freqAntecedent: Double) extends Serializable {

/**
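Promoting freqUnion from a plain constructor argument to a private[spark] val is what lets the ml-side code above read the co-occurrence count. Rule's public API is unchanged, and adding a spark-internal accessor is binary-compatible, which is why only the FPGrowthModel constructor needs a MiMa exclusion below.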
22 changes: 17 additions & 5 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -38,14 +38,15 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
val model = new FPGrowth().setMinSupport(0.5).fit(data)
val generatedRules = model.setMinConfidence(0.5).associationRules
+      generatedRules.show()
val expectedRules = spark.createDataFrame(Seq(
(Array("2"), Array("1"), 1.0),
(Array("1"), Array("2"), 0.75)
)).toDF("antecedent", "consequent", "confidence")
(Array("2"), Array("1"), 1.0, 0.75),
(Array("1"), Array("2"), 0.75, 0.75)
)).toDF("antecedent", "consequent", "confidence", "support")
.withColumn("antecedent", col("antecedent").cast(ArrayType(dt)))
.withColumn("consequent", col("consequent").cast(ArrayType(dt)))
assert(expectedRules.sort("antecedent").rdd.collect().sameElements(
generatedRules.sort("antecedent").rdd.collect()))
assert(expectedRules.collect().toSet.equals(
generatedRules.collect().toSet))

val transformed = model.transform(data)
val expectedTransformed = spark.createDataFrame(Seq(
@@ -75,6 +76,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(checkDF.count() == 3 && checkDF.filter(col("freq") === col("expectedFreq")).count() == 3)
}

test("FPGrowth associationRules") {
val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
val expectedRules = spark.createDataFrame(Seq(
(Array("2"), Array("1"), 1.0, 0.75),
(Array("3"), Array("1"), 1.0, 0.25),
(Array("1"), Array("3"), 0.25, 0.25),
(Array("1"), Array("2"), 0.75, 0.75)
)).toDF("antecedent", "consequent", "confidence", "support")
assert(expectedRules.collect().toSet.equals(model.associationRules.collect().toSet))
}

test("FPGrowth getFreqItems with Null") {
val df = spark.createDataFrame(Seq(
(1, Array("1", "2", "3", "5")),
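The expected tuples in the new test are self-consistent with a four-transaction fixture (the suite's shared dataset, defined outside this hunk) in which every transaction contains "1", three also contain "2", and one also contains "3": freq({1,2}) = 3 gives {2} => {1} confidence 3/3 = 1.0 and support 3/4 = 0.75, while freq({1,3}) = 1 gives {1} => {3} confidence 1/4 = 0.25 and support 1/4 = 0.25.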
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -101,7 +101,10 @@ object MimaExcludes {
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),

// [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier
-    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter")
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"),
+
+    // [SPARK-19939][ML] Add support for association rules in ML
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this")
)

// Exclude rules for 2.3.x
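This exclusion is required because adding numTrainingRecords changes FPGrowthModel's constructor signature ("FPGrowthModel.this" in MiMa's naming), which the binary-compatibility check would otherwise report against the previous release.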
38 changes: 19 additions & 19 deletions python/pyspark/ml/fpm.py
@@ -187,29 +187,29 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol,
|[z] |
|[x, z, y, r, q, t, p] |
+------------------------+
-    >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
+    >>> fp = FPGrowth(minSupport=0.4, minConfidence=0.7)
>>> fpm = fp.fit(data)
>>> fpm.freqItemsets.show(5)
-    +---------+----+
-    |    items|freq|
-    +---------+----+
-    |      [s]|   3|
-    |   [s, x]|   3|
-    |[s, x, z]|   2|
-    |   [s, z]|   2|
-    |      [r]|   3|
-    +---------+----+
+    +------+----+
+    | items|freq|
+    +------+----+
+    |   [s]|   3|
+    |[s, x]|   3|
+    |   [r]|   3|
+    |   [y]|   3|
+    |[y, x]|   3|
+    +------+----+
    only showing top 5 rows

[inline review] Member: this seems to change the result quite a bit, is this expected?
[inline review] Contributor (author): set a higher minSupport to avoid the 0.3333333... in the support column.
>>> fpm.associationRules.show(5)
-    +----------+----------+----------+
-    |antecedent|consequent|confidence|
-    +----------+----------+----------+
-    |    [t, s]|       [y]|       1.0|
-    |    [t, s]|       [x]|       1.0|
-    |    [t, s]|       [z]|       1.0|
-    |       [p]|       [r]|       1.0|
-    |       [p]|       [z]|       1.0|
-    +----------+----------+----------+
+    +----------+----------+----------+-------+
+    |antecedent|consequent|confidence|support|
+    +----------+----------+----------+-------+
+    |       [t]|       [y]|       1.0|    0.5|
+    |       [t]|       [x]|       1.0|    0.5|
+    |       [t]|       [z]|       1.0|    0.5|
+    | [y, t, x]|       [z]|       1.0|    0.5|
+    |       [x]|       [s]|      0.75|    0.5|
+    +----------+----------+----------+-------+
only showing top 5 rows
>>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"])
>>> sorted(fpm.transform(new_data).first().prediction)
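The revised doctest numbers hang together: every displayed support is 0.5 with freq 3, implying six training transactions (only two of the data rows are visible above). Raising minSupport from 0.2 to 0.4 prunes itemsets of frequency 2 such as [s, x, z], which both shortens the frequent-itemset listing and, per the author's inline reply, keeps repeating decimals like 0.333... out of the support column.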
4 changes: 2 additions & 2 deletions python/pyspark/ml/tests.py
@@ -2158,8 +2158,8 @@ def test_association_rules(self):
fpm = fp.fit(self.data)

expected_association_rules = self.spark.createDataFrame(
-            [([3], [1], 1.0), ([2], [1], 1.0)],
-            ["antecedent", "consequent", "confidence"]
+            [([3], [1], 1.0, 0.5), ([2], [1], 1.0, 0.75)],
+            ["antecedent", "consequent", "confidence", "support"]
)
actual_association_rules = fpm.associationRules

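As in the R test, the new expected supports follow freq/N: 0.5 and 0.75 are consistent with a four-record fixture (self.data, defined elsewhere in the suite) in which {1, 3} co-occurs twice and {1, 2} three times.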