Skip to content

Commit 0137d67

Browse files
committed
Unpersist if input is not originally cached. Add deprecated info.
1 parent d6fed35 commit 0137d67

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal}
3333
import org.apache.spark.sql.expressions.Aggregator
3434
import org.apache.spark.sql.functions._
3535
import org.apache.spark.sql.types._
36+
import org.apache.spark.storage.StorageLevel
3637
import org.apache.spark.util.VersionUtils.majorMinorVersion
3738
import org.apache.spark.util.collection.OpenHashMap
3839

@@ -205,6 +206,10 @@ class StringIndexer @Since("1.4.0") (
205206

206207
val (inputCols, _) = getInOutCols()
207208

209+
// If input dataset is not originally cached, we need to unpersist it
210+
// once we persist it later.
211+
val needUnpersist = dataset.storageLevel == StorageLevel.NONE
212+
208213
// In case of equal frequency when frequencyDesc/Asc, the strings are further sorted
209214
// alphabetically.
210215
val labelsArray = $(stringOrderType) match {
@@ -225,7 +230,9 @@ class StringIndexer @Since("1.4.0") (
225230
dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").desc)
226231
.as[String].collect()
227232
}
228-
dataset.unpersist()
233+
if (needUnpersist) {
234+
dataset.unpersist()
235+
}
229236
labels
230237
case StringIndexer.alphabetAsc =>
231238
import dataset.sparkSession.implicits._
@@ -234,7 +241,9 @@ class StringIndexer @Since("1.4.0") (
234241
dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").asc)
235242
.as[String].collect()
236243
}
237-
dataset.unpersist()
244+
if (needUnpersist) {
245+
dataset.unpersist()
246+
}
238247
labels
239248
}
240249
copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
@@ -309,12 +318,16 @@ class StringIndexerModel (
309318

310319
import StringIndexerModel._
311320

321+
@deprecated("`this(labels: Array[String])` is deprecated and will be removed in 3.1.0. " +
322+
"Use `this(labelsArray: Array[Array[String]])` instead.", "3.0.0")
312323
@Since("1.5.0")
313324
def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), Array(labels))
314325

315326
@Since("3.0.0")
316327
def this(labelsArray: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labelsArray)
317328

329+
@deprecated("`labels` is deprecated and will be removed in 3.1.0. Use `labelsArray` " +
330+
"instead.", "3.0.0")
318331
@Since("1.5.0")
319332
def labels: Array[String] = {
320333
require(labelsArray.length == 1, "This StringIndexerModel is fitted by multi-columns, " +

0 commit comments

Comments
 (0)