@@ -26,21 +26,23 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature._
 import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.optim.aggregator.HingeAggregator
+import org.apache.spark.ml.optim.aggregator._
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.stat._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.storage.StorageLevel
 
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
   with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-  with HasAggregationDepth with HasThreshold {
+  with HasAggregationDepth with HasThreshold with HasBlockSize {
 
   /**
    * Param for threshold in binary classification prediction.
@@ -154,31 +156,65 @@ class LinearSVC @Since("2.2.0") (
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Set block size for stacking input data in matrices.
+   * If blockSize == 1, then stacking will be skipped, and each vector is treated individually;
+   * if blockSize > 1, then vectors will be stacked into blocks, and high-level BLAS routines
+   * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
+   * Recommended size is between 10 and 1000. An appropriate choice of block size depends on
+   * the sparsity and dimension of the input dataset, the underlying BLAS implementation (for
+   * example, f2jBLAS, OpenBLAS, Intel MKL) and its configuration (for example, number of threads).
+   * Note that existing BLAS implementations are mainly optimized for dense matrices; if the
+   * input dataset is sparse, stacking may bring no performance gain, or even a performance
+   * regression.
+   * Default is 1.
+   *
+   * @group expertSetParam
+   */
+  @Since("3.1.0")
+  def setBlockSize(value: Int): this.type = set(blockSize, value)
+  setDefault(blockSize -> 1)
+
   @Since("2.2.0")
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
   override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-
-    val instances = extractInstances(dataset)
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
-      regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth)
+      regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth, blockSize)
+
+    val instances = extractInstances(dataset)
+      .setName("training instances")
 
-    val (summarizer, labelSummarizer) =
+    val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
+      if (dataset.storageLevel == StorageLevel.NONE) {
+        instances.persist(StorageLevel.MEMORY_AND_DISK)
+      }
       Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
-    instr.logNumExamples(summarizer.count)
-    instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
-    instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
-    instr.logSumOfWeights(summarizer.weightSum)
+    } else {
+      // instances will be standardized and converted to blocks, so no need to cache instances.
+      Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
+        Seq("mean", "std", "count", "numNonZeros"))
+    }
 
     val histogram = labelSummarizer.histogram
     val numInvalid = labelSummarizer.countInvalid
     val numFeatures = summarizer.mean.size
-    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
+
+    instr.logNumExamples(summarizer.count)
+    instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+    instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
+    instr.logSumOfWeights(summarizer.weightSum)
+    if ($(blockSize) > 1) {
+      val scale = 1.0 / summarizer.count / numFeatures
+      val sparsity = 1 - summarizer.numNonzeros.toArray.map(_ * scale).sum
+      instr.logNamedValue("sparsity", sparsity.toString)
+      if (sparsity > 0.5) {
+        instr.logWarning(s"sparsity of input dataset is $sparsity, " +
+          s"which may hurt performance in high-level BLAS.")
+      }
+    }
 
     val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
       case Some(n: Int) =>
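
Aside (an illustration, not part of the commit): blocking pays off because k hinge-margin dot products coef . x_i collapse into a single matrix-vector product over a k x numFeatures block. A rough sketch using Breeze, which Spark ML already depends on; all numbers are made up:

import breeze.linalg.{DenseMatrix, DenseVector}

val coef = DenseVector(0.5, -1.0, 2.0)
// with blockSize = 4, four 3-dimensional instances are stacked row-wise into a 4 x 3 block
val block = DenseMatrix(
  (1.0, 0.0, 2.0),
  (0.0, 1.0, 1.0),
  (3.0, 0.0, 0.0),
  (1.0, 1.0, 1.0))
// one GEMV (matrix-vector multiply) replaces four separate DOT calls
val margins = block * coef  // DenseVector(4.5, 1.0, 1.5, 1.5)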
@@ -192,77 +228,113 @@ class LinearSVC @Since("2.2.0") (
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
 
-    val (coefficientVector, interceptVector, objectiveHistory) = {
-      if (numInvalid != 0) {
-        val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
-          s"Found $numInvalid invalid labels."
-        instr.logError(msg)
-        throw new SparkException(msg)
-      }
+    if (numInvalid != 0) {
+      val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
+        s"Found $numInvalid invalid labels."
+      instr.logError(msg)
+      throw new SparkException(msg)
+    }
 
-      val featuresStd = summarizer.std.toArray
-      val getFeaturesStd = (j: Int) => featuresStd(j)
-      val regParamL2 = $(regParam)
-      val bcFeaturesStd = instances.context.broadcast(featuresStd)
-      val regularization = if (regParamL2 != 0.0) {
-        val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
-        Some(new L2Regularization(regParamL2, shouldApply,
-          if ($(standardization)) None else Some(getFeaturesStd)))
-      } else {
-        None
-      }
+    val featuresStd = summarizer.std.toArray
+    val getFeaturesStd = (j: Int) => featuresStd(j)
+    val regularization = if ($(regParam) != 0.0) {
+      val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
+      Some(new L2Regularization($(regParam), shouldApply,
+        if ($(standardization)) None else Some(getFeaturesStd)))
+    } else None
+
+    def regParamL1Fun = (index: Int) => 0.0
+    val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+
+    /*
+       The coefficients are trained in the scaled space; we're converting them back to
+       the original space.
+       Note that the intercept in scaled space and original space is the same;
+       as a result, no scaling is needed.
+     */
+    val (rawCoefficients, objectiveHistory) = if ($(blockSize) == 1) {
+      trainOnRows(instances, featuresStd, regularization, optimizer)
+    } else {
+      trainOnBlocks(instances, featuresStd, regularization, optimizer)
+    }
+    if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
 
-      val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
-      val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
-        $(aggregationDepth))
+    if (rawCoefficients == null) {
+      val msg = s"${optimizer.getClass.getName} failed."
+      instr.logError(msg)
+      throw new SparkException(msg)
+    }
 
-      def regParamL1Fun = (index: Int) => 0D
-      val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
-      val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)
+    val coefficientArray = Array.tabulate(numFeatures) { i =>
+      if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0
+    }
+    val intercept = if ($(fitIntercept)) rawCoefficients.last else 0.0
+    copyValues(new LinearSVCModel(uid, Vectors.dense(coefficientArray), intercept))
+  }
 
-      val states = optimizer.iterations(new CachedDiffFunction(costFun),
-        initialCoefWithIntercept.asBreeze.toDenseVector)
+  private def trainOnRows(
+      instances: RDD[Instance],
+      featuresStd: Array[Double],
+      regularization: Option[L2Regularization],
+      optimizer: BreezeOWLQN[Int, BDV[Double]]): (Array[Double], Array[Double]) = {
+    val numFeatures = featuresStd.length
+    val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
+
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+    val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+    val costFun = new RDDLossFunction(instances, getAggregatorFunc,
+      regularization, $(aggregationDepth))
+
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      Vectors.zeros(numFeaturesPlusIntercept).asBreeze.toDenseVector)
+
+    val arrayBuilder = mutable.ArrayBuilder.make[Double]
+    var state: optimizer.State = null
+    while (states.hasNext) {
+      state = states.next()
+      arrayBuilder += state.adjustedValue
+    }
+    bcFeaturesStd.destroy()
 
-      val scaledObjectiveHistory = mutable.ArrayBuilder.make[Double]
-      var state: optimizer.State = null
-      while (states.hasNext) {
-        state = states.next()
-        scaledObjectiveHistory += state.adjustedValue
-      }
+    (if (state != null) state.x.toArray else null, arrayBuilder.result)
+  }
 
-      bcFeaturesStd.destroy()
-      if (state == null) {
-        val msg = s"${optimizer.getClass.getName} failed."
-        instr.logError(msg)
-        throw new SparkException(msg)
-      }
+  private def trainOnBlocks(
+      instances: RDD[Instance],
+      featuresStd: Array[Double],
+      regularization: Option[L2Regularization],
+      optimizer: BreezeOWLQN[Int, BDV[Double]]): (Array[Double], Array[Double]) = {
+    val numFeatures = featuresStd.length
+    val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
 
-      /*
-         The coefficients are trained in the scaled space; we're converting them back to
-         the original space.
-         Note that the intercept in scaled space and original space is the same;
-         as a result, no scaling is needed.
-       */
-      val rawCoefficients = state.x.toArray
-      val coefficientArray = Array.tabulate(numFeatures) { i =>
-        if (featuresStd(i) != 0.0) {
-          rawCoefficients(i) / featuresStd(i)
-        } else {
-          0.0
-        }
-      }
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
 
-      val intercept = if ($(fitIntercept)) {
-        rawCoefficients(numFeaturesPlusIntercept - 1)
-      } else {
-        0.0
-      }
-      (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result())
+    val standardized = instances.mapPartitions { iter =>
+      val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 }
+      val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true)
+      iter.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) }
     }
+    val blocks = InstanceBlock.blokify(standardized, $(blockSize))
+      .persist(StorageLevel.MEMORY_AND_DISK)
+      .setName(s"training dataset (blockSize=${$(blockSize)})")
+
+    val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_)
+    val costFun = new RDDLossFunction(blocks, getAggregatorFunc,
+      regularization, $(aggregationDepth))
+
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      Vectors.zeros(numFeaturesPlusIntercept).asBreeze.toDenseVector)
+
+    val arrayBuilder = mutable.ArrayBuilder.make[Double]
+    var state: optimizer.State = null
+    while (states.hasNext) {
+      state = states.next()
+      arrayBuilder += state.adjustedValue
+    }
+    blocks.unpersist()
+    bcFeaturesStd.destroy()
 
-    if (handlePersistence) instances.unpersist()
-
-    copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector))
+    (if (state != null) state.x.toArray else null, arrayBuilder.result)
   }
 }
 
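A minimal usage sketch of the new param (not part of the commit), assuming a DataFrame `training` with the default "label" and "features" columns:

import org.apache.spark.ml.classification.LinearSVC

val svc = new LinearSVC()
  .setMaxIter(50)
  .setRegParam(0.1)
  .setBlockSize(64)  // stack 64 instances per block so the aggregator can use GEMV/GEMM
val model = svc.fit(training)
println(s"coefficients=${model.coefficients} intercept=${model.intercept}")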
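The sparsity figure logged for blockSize > 1 is plain arithmetic: one minus the fraction of nonzero entries in the count x numFeatures input matrix. A self-contained sketch mirroring that computation, with made-up numbers:

// hypothetical per-feature nonzero counts for a 4-row, 3-feature input
val numNonzeros = Array(3.0, 0.0, 1.0)
val count = 4L
val numFeatures = numNonzeros.length
val scale = 1.0 / count / numFeatures
val sparsity = 1 - numNonzeros.map(_ * scale).sum
// 4 of 12 entries are nonzero, so sparsity = 8.0 / 12 ≈ 0.67, above the 0.5
// threshold that triggers the high-level BLAS warning during training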