[SPARK-10697][ML] Add lift to Association rules #22236
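This patch adds the lift metric to association rule mining. For a rule X => Y, lift is confidence(X => Y) / support(Y), i.e. how much more often X and Y occur together than would be expected if they were independent. To compute it, `FPGrowthModel` now keeps the per-item support map and the training record count, persists the count in the model metadata, and exposes a new nullable `lift` column in `associationRules` (null when the support information is unavailable, e.g. for models saved by Spark 2.3 and earlier).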
**mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala**
```diff
@@ -20,6 +20,8 @@ package org.apache.spark.ml.fpm
 import scala.reflect.ClassTag

 import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._

 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
```
|
```diff
@@ -34,6 +36,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils

 /**
  * Common params for FPGrowth and FPGrowthModel
```
```diff
@@ -175,7 +178,8 @@ class FPGrowth @Since("2.2.0") (
     if (handlePersistence) {
       items.persist(StorageLevel.MEMORY_AND_DISK)
     }
-
+    val inputRowCount = items.count()
+    instr.logNumExamples(inputRowCount)
     val parentModel = mllibFP.run(items)
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
     val schema = StructType(Seq(
```
```diff
@@ -187,7 +191,8 @@ class FPGrowth @Since("2.2.0") (
       items.unpersist()
     }

-    copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+    copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount))
+      .setParent(this)
   }

   @Since("2.2.0")
```
```diff
@@ -217,7 +222,9 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
 @Experimental
 class FPGrowthModel private[ml] (
     @Since("2.2.0") override val uid: String,
-    @Since("2.2.0") @transient val freqItemsets: DataFrame)
+    @Since("2.2.0") @transient val freqItemsets: DataFrame,
+    private val itemSupport: scala.collection.Map[Any, Double],
+    private val numTrainingRecords: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {

   /** @group setParam */
```

> **Member:** I suppose there's no way of adding the item generic type here; it's really in the schema of the DataFrame. If you have the support for every item, do you need the overall count here as well? The item counts have already been divided through by numTrainingRecords here; below, only itemSupport is really passed anywhere else.
>
> **Contributor (author):** Yes. The overall count is still needed for save/load: the itemSupport map is not persisted, so numTrainingRecords is written to the metadata and used to recompute it when the model is loaded.
```diff
@@ -241,17 +248,17 @@ class FPGrowthModel private[ml] (
   @transient private var _cachedRules: DataFrame = _

   /**
-   * Get association rules fitted using the minConfidence. Returns a dataframe
-   * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
-   * "consequent" are Array[T] and "confidence" is Double.
+   * Get association rules fitted using the minConfidence. Returns a dataframe with four fields,
+   * "antecedent", "consequent", "confidence" and "lift", where "antecedent" and "consequent" are
+   * Array[T], whereas "confidence" and "lift" are Double.
    */
   @Since("2.2.0")
   @transient def associationRules: DataFrame = {
     if ($(minConfidence) == _cachedMinConf) {
       _cachedRules
     } else {
       _cachedRules = AssociationRules
-        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+        .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport)
       _cachedMinConf = $(minConfidence)
       _cachedRules
     }
```
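For context, a minimal usage sketch of the new column. The toy data and parameter values here are illustrative, not part of this patch:

```scala
import org.apache.spark.ml.fpm.FPGrowth
import spark.implicits._ // assumes an active SparkSession named `spark`

// Toy transactions: each row is one basket of items.
val dataset = spark.createDataset(Seq(
  "1 2 5",
  "1 2 3 5",
  "1 2"
)).map(_.split(" ")).toDF("items")

val model = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.5)
  .setMinConfidence(0.6)
  .fit(dataset)

// With this patch the result carries four columns:
// antecedent, consequent, confidence, and the new lift column.
model.associationRules.show()
```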
```diff
@@ -301,7 +308,7 @@ class FPGrowthModel private[ml] (

   @Since("2.2.0")
   override def copy(extra: ParamMap): FPGrowthModel = {
-    val copied = new FPGrowthModel(uid, freqItemsets)
+    val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, numTrainingRecords)
     copyValues(copied, extra).setParent(this.parent)
   }
```
```diff
@@ -323,7 +330,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
       instance.freqItemsets.write.parquet(dataPath)
     }
```
```diff
@@ -335,10 +343,28 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName

     override def load(path: String): FPGrowthModel = {
+      implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+      val numTrainingRecords = if (major.toInt < 2 || (major.toInt == 2 && minor.toInt < 4)) {
+        // 2.3 and before don't store the count
+        0L
+      } else {
+        // 2.4+
+        (metadata.metadata \ "numTrainingRecords").extract[Long]
+      }
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
-      val model = new FPGrowthModel(metadata.uid, frequentItems)
+      val itemSupport = if (numTrainingRecords == 0L) {
+        Map.empty[Any, Double]
+      } else {
+        frequentItems.rdd.flatMap {
+          case Row(items: Seq[_], count: Long) if items.length == 1 =>
+            Some(items.head -> count.toDouble / numTrainingRecords)
+          case _ => None
+        }.collectAsMap()
+      }
+      val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, numTrainingRecords)
       metadata.getAndSetParams(model)
       model
     }
```
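To make the reconstruction above concrete, here is the same singleton-support computation sketched on plain Scala collections; the itemsets and the count are made up for illustration:

```scala
// Hypothetical frequent itemsets as (items, count), with n = 100 training records.
val n = 100L
val freqItemsets: Seq[(Seq[String], Long)] =
  Seq((Seq("a"), 30L), (Seq("b"), 60L), (Seq("a", "b"), 20L))

// Only singleton itemsets contribute: support(item) = count / n.
val itemSupport: Map[String, Double] = freqItemsets.collect {
  case (items, count) if items.length == 1 => items.head -> count.toDouble / n
}.toMap
// Map("a" -> 0.3, "b" -> 0.6); the {a, b} itemset is skipped,
// since only per-item support is needed for lift.
```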
```diff
@@ -354,27 +380,30 @@ private[fpm] object AssociationRules {
    * @param itemsCol column name for frequent itemsets
    * @param freqCol column name for appearance count of the frequent itemsets
    * @param minConfidence minimum confidence for generating the association rules
-   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
-   *         containing the association rules.
+   * @param itemSupport map containing an item and its support
+   * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double],
+   *         "lift" [Double]) containing the association rules.
    */
   def getAssociationRulesFromFP[T: ClassTag](
       dataset: Dataset[_],
       itemsCol: String,
       freqCol: String,
-      minConfidence: Double): DataFrame = {
+      minConfidence: Double,
+      itemSupport: scala.collection.Map[T, Double]): DataFrame = {

     val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd
       .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1)))
     val rows = new MLlibAssociationRules()
       .setMinConfidence(minConfidence)
-      .run(freqItemSetRdd)
-      .map(r => Row(r.antecedent, r.consequent, r.confidence))
+      .run(freqItemSetRdd, itemSupport)
+      .map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull))

     val dt = dataset.schema(itemsCol).dataType
     val schema = StructType(Seq(
       StructField("antecedent", dt, nullable = false),
       StructField("consequent", dt, nullable = false),
-      StructField("confidence", DoubleType, nullable = false)))
+      StructField("confidence", DoubleType, nullable = false),
+      StructField("lift", DoubleType)))
     val rules = dataset.sparkSession.createDataFrame(rows, schema)
     rules
   }
```
**mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala**
```diff
@@ -56,11 +56,24 @@ class AssociationRules private[fpm] (
   /**
    * Computes the association rules with confidence above `minConfidence`.
    * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
-   * @return a `Set[Rule[Item]]` containing the association rules.
+   * @return a `RDD[Rule[Item]]` containing the association rules.
    *
    */
   @Since("1.5.0")
   def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
+    run(freqItemsets, Map.empty[Item, Double])
+  }
+
+  /**
+   * Computes the association rules with confidence above `minConfidence`.
+   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
+   * @param itemSupport map containing an item and its support
+   * @return a `RDD[Rule[Item]]` containing the association rules. The rules will also be
+   *         able to compute the lift metric.
+   */
+  @Since("2.4.0")
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]],
+      itemSupport: scala.collection.Map[Item, Double]): RDD[Rule[Item]] = {
     // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
     val candidates = freqItemsets.flatMap { itemset =>
       val items = itemset.items
```
```diff
@@ -76,8 +89,13 @@ class AssociationRules private[fpm] (
     // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
     candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
       .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) =>
-        new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent)
-      }.filter(_.confidence >= minConfidence)
+        new Rule(antecendent.toArray,
+          consequent.toArray,
+          freqUnion,
+          freqAntecedent,
+          // the consequent always contains exactly one element
+          itemSupport.get(consequent.head))
+      }.filter(_.confidence >= minConfidence)
   }

   /**
```
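A sketch of driving the new overload directly from mllib; the itemsets, supports, and the `sc` SparkContext are assumed for illustration:

```scala
import org.apache.spark.mllib.fpm.AssociationRules
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset

// Assumes an active SparkContext named `sc`; itemsets and counts are made up.
val freqItemsets = sc.parallelize(Seq(
  new FreqItemset(Array("a"), 15L),
  new FreqItemset(Array("b"), 35L),
  new FreqItemset(Array("a", "b"), 12L)
))

// Supports must be precomputed against the training count, here assumed n = 60.
val itemSupport = Map("a" -> 15.0 / 60, "b" -> 35.0 / 60)

val rules = new AssociationRules()
  .setMinConfidence(0.6)
  .run(freqItemsets, itemSupport)

// {a} => {b}: confidence = 12/15 = 0.8, lift = Some(0.8 / (35/60.0)) ≈ Some(1.37)
rules.collect().foreach(println)
```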
```diff
@@ -107,14 +125,21 @@ object AssociationRules {
     @Since("1.5.0") val antecedent: Array[Item],
     @Since("1.5.0") val consequent: Array[Item],
     freqUnion: Double,
-    freqAntecedent: Double) extends Serializable {
+    freqAntecedent: Double,
+    freqConsequent: Option[Double]) extends Serializable {

     /**
      * Returns the confidence of the rule.
      */
     @Since("1.5.0")
-    def confidence: Double = freqUnion.toDouble / freqAntecedent
+    def confidence: Double = freqUnion / freqAntecedent

+    /**
+     * Returns the lift of the rule.
+     */
+    @Since("2.4.0")
+    def lift: Option[Double] = freqConsequent.map(fCons => confidence / fCons)

     require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
       val sharedItems = antecedent.toSet.intersect(consequent.toSet)
```

> **Member:** Ideally these frequencies would have been Longs, I think, but it's too late now. Yes, stay consistent.
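For intuition, here are the two metrics for a hypothetical rule {a} => {b} mined from n = 60 baskets, with freq({a}) = 15, freq({b}) = 35, and freq({a, b}) = 12:

```
confidence({a} => {b}) = freq({a, b}) / freq({a})  = 12 / 15       = 0.8
lift({a} => {b})       = confidence / support({b}) = 0.8 / (35/60) ≈ 1.37
```

A lift above 1 means the two sides co-occur more often than they would if they were independent; exactly 1 means independence, and below 1 a negative association.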
```diff
@@ -142,7 +167,7 @@ object AssociationRules {

     override def toString: String = {
       s"${antecedent.mkString("{", ",", "}")} => " +
-        s"${consequent.mkString("{", ",", "}")}: ${confidence}"
+        s"${consequent.mkString("{", ",", "}")}: (confidence: $confidence; lift: $lift)"
     }
   }
 }
```
**project/MimaExcludes.scala**
```diff
@@ -36,6 +36,10 @@ object MimaExcludes {

   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-10697][ML] Add lift to Association rules
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"),
+
     // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"),
```

> **Member:** These are for the private[ml] constructors, right? OK to suppress, yes.
>
> **Contributor (author):** Yes, they are the private ones.
> Ah, so the total count n won't just equal the sum of the counts of the consequents, for example, because the frequent itemsets were pruned of infrequent sets? Darn. Yeah, you need to deal with both the case where you know n and the case where you don't.

> Well, there is no way to know the total size of the training dataset from the frequent itemsets, unfortunately. So yes, we need to handle both cases.
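As a concrete illustration of the problem (numbers made up): with 10 baskets and minSupport = 0.3, an item appearing in only 2 baskets is pruned from the frequent itemsets entirely, while a basket containing two surviving items is counted once in each of their singleton frequencies. So neither the sum of the surviving counts nor any other aggregate of the itemsets recovers n = 10, which is why models saved before this change fall back to an empty support map and a null lift.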