@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal}
3333import org .apache .spark .sql .expressions .Aggregator
3434import org .apache .spark .sql .functions ._
3535import org .apache .spark .sql .types ._
36+ import org .apache .spark .storage .StorageLevel
3637import org .apache .spark .util .VersionUtils .majorMinorVersion
3738import org .apache .spark .util .collection .OpenHashMap
3839
@@ -205,6 +206,10 @@ class StringIndexer @Since("1.4.0") (
205206
206207 val (inputCols, _) = getInOutCols()
207208
209+ // If input dataset is not originally cached, we need to unpersist it
210+ // once we persist it later.
211+ val needUnpersist = dataset.storageLevel == StorageLevel .NONE
212+
208213 // In case of equal frequency when frequencyDesc/Asc, the strings are further sorted
209214 // alphabetically.
210215 val labelsArray = $(stringOrderType) match {
@@ -225,7 +230,9 @@ class StringIndexer @Since("1.4.0") (
225230 dataset.select(inputCol).na.drop().distinct().sort(dataset(s " $inputCol" ).desc)
226231 .as[String ].collect()
227232 }
228- dataset.unpersist()
233+ if (needUnpersist) {
234+ dataset.unpersist()
235+ }
229236 labels
230237 case StringIndexer .alphabetAsc =>
231238 import dataset .sparkSession .implicits ._
@@ -234,7 +241,9 @@ class StringIndexer @Since("1.4.0") (
234241 dataset.select(inputCol).na.drop().distinct().sort(dataset(s " $inputCol" ).asc)
235242 .as[String ].collect()
236243 }
237- dataset.unpersist()
244+ if (needUnpersist) {
245+ dataset.unpersist()
246+ }
238247 labels
239248 }
240249 copyValues(new StringIndexerModel (uid, labelsArray).setParent(this ))
@@ -309,12 +318,16 @@ class StringIndexerModel (
309318
310319 import StringIndexerModel ._
311320
321+ @ deprecated(" `this(labels: Array[String])` is deprecated and will be removed in 3.1.0. " +
322+ " Use `this(labelsArray: Array[Array[String]])` instead." , " 3.0.0" )
312323 @ Since (" 1.5.0" )
313324 def this (labels : Array [String ]) = this (Identifiable .randomUID(" strIdx" ), Array (labels))
314325
315326 @ Since (" 3.0.0" )
316327 def this (labelsArray : Array [Array [String ]]) = this (Identifiable .randomUID(" strIdx" ), labelsArray)
317328
329+ @ deprecated(" `labels` is deprecated and will be removed in 3.1.0. Use `labelsArray` " +
330+ " instead." , " 3.0.0" )
318331 @ Since (" 1.5.0" )
319332 def labels : Array [String ] = {
320333 require(labelsArray.length == 1 , " This StringIndexerModel is fitted by multi-columns, " +
0 commit comments