@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -140,6 +140,139 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   }
 }
 
+/**
+ * `MultipleBucketizer` maps columns of continuous features to columns of feature buckets.
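+ *
+ * The multi-column counterpart of [[Bucketizer]]. A usage sketch (the column names and input
+ * `df` are hypothetical; each input column must hold Doubles, and each splits array covers
+ * all values via infinite endpoints):
+ * {{{
+ *   val bucketizer = new MultipleBucketizer()
+ *     .setInputCols(Array("hour", "temperature"))
+ *     .setOutputCols(Array("hourBucket", "tempBucket"))
+ *     .setSplitsArray(Array(
+ *       Array(Double.NegativeInfinity, 6.0, 12.0, 18.0, Double.PositiveInfinity),
+ *       Array(Double.NegativeInfinity, 0.0, 25.0, Double.PositiveInfinity)))
+ *   val bucketed = bucketizer.transform(df)
+ * }}}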
+ */
+@Since("2.3.0")
+final class MultipleBucketizer @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+  extends Model[MultipleBucketizer] with HasInputCols with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("multipleBucketizer"))
+
+  /**
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n
+   * buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last
+   * bucket, which also includes y. Splits should be of length greater than or equal to 3 and
+   * strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double
+   * values; otherwise, values outside the splits specified will be treated as errors.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
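+   * For example (illustrative values only), the splits Array(Double.NegativeInfinity, 0.0, 1.0,
+   * Double.PositiveInfinity) yield three buckets: [-inf, 0.0), [0.0, 1.0), and [1.0, inf].
+   *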
+   * @group param
+   */
+  @Since("2.3.0")
+  val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
+    "The array of split points for mapping continuous features into buckets for multiple " +
+    "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
+    "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
+    "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
+    "explicitly provided to cover all Double values; otherwise, values outside the splits " +
+    "specified will be treated as errors.",
+    Bucketizer.checkSplitsArray)
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  @Since("2.3.0")
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
+    "output column names")
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSplitsArray: Array[Array[Double]] = $(splitsArray)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  final def getOutputCols: Array[String] = $(outputCols)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
+   * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
+   * additional bucket).
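+   *
+   * For example (illustrative values), with splits [-inf, 0.0, inf] and handleInvalid set to
+   * "keep", a NaN feature is assigned bucket index 2, one past the last regular bucket.
+   *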
+   * Default: "error"
+   * @group param
+   */
+  // TODO: Make MultipleBucketizer inherit from HasHandleInvalid.
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+  @Since("2.3.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
+    val (filteredDataset, keepInvalid) = {
+      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
+        // handleInvalid is "skip": drop rows containing null or NaN values before bucketizing
+        (dataset.na.drop().toDF(), false)
+      } else {
+        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
+      }
+    }
+
+    // One bucketizing UDF per input column, each closing over that column's splits.
+    val bucketizers: Seq[UserDefinedFunction] = $(splitsArray).map { splits =>
+      udf { (feature: Double) =>
+        Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
+      }
+    }
+
+    val newCols = $(inputCols).zipWithIndex.map { case (inputCol, idx) =>
+      bucketizers(idx)(filteredDataset(inputCol))
+    }
+    val newFields = $(outputCols).zipWithIndex.map { case (outputCol, idx) =>
+      prepOutputField(idx, outputCol)
+    }
+    filteredDataset.withColumns($(outputCols), newCols, newFields.map(_.metadata))
+  }
+
+  private def prepOutputField(idx: Int, outputCol: String): StructField = {
+    val buckets = $(splitsArray)(idx).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
+      values = Some(buckets))
+    attr.toStructField()
+  }
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var transformedSchema = schema
+    $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
+      SchemaUtils.checkColumnType(transformedSchema, inputCol, DoubleType)
+      transformedSchema = SchemaUtils.appendColumn(transformedSchema,
+        prepOutputField(idx, outputCol))
+    }
+    transformedSchema
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): MultipleBucketizer = {
+    defaultCopy[MultipleBucketizer](extra).setParent(parent)
+  }
+}
+
 @Since("1.6.0")
 object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
@@ -167,6 +300,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
     }
   }
 
+  /**
+   * Check each set of splits in the splits array.
+   */
+  private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
+    splitsArray.forall(checkSplits)
+  }
+
   /**
    * Binary searching in several buckets to place each data point.
    * @param splits array of split points
@@ -211,3 +351,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
   @Since("1.6.0")
   override def load(path: String): Bucketizer = super.load(path)
 }
+
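+/**
+ * Companion object enabling persistence via [[DefaultParamsReadable]]. A save/load round-trip
+ * sketch (the path is hypothetical):
+ * {{{
+ *   bucketizer.write.overwrite().save("/tmp/multipleBucketizer")
+ *   val restored = MultipleBucketizer.load("/tmp/multipleBucketizer")
+ * }}}
+ */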
+@Since("2.3.0")
+object MultipleBucketizer extends DefaultParamsReadable[MultipleBucketizer] {
+  @Since("2.3.0")
+  override def load(path: String): MultipleBucketizer = super.load(path)
+}