Skip to content

Commit a067da2

Browse files
committed
a new approach for poly expansion
1 parent 0789d81 commit a067da2

File tree

1 file changed

+117
-4
lines changed

1 file changed

+117
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialMapper.scala

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.ml.feature
1919

2020
import scala.annotation.tailrec
21+
import scala.collection.mutable
2122
import scala.collection.mutable.ArrayBuffer
2223

2324
import org.apache.spark.annotation.AlphaComponent
@@ -41,16 +42,19 @@ class PolynomialMapper extends UnaryTransformer[Vector, Vector, PolynomialMapper
4142
* The polynomial degree to expand, which should be larger than 1.
4243
* @group param
4344
*/
44-
val degree = new IntParam(this, "degree", "the polynomial degree to expand", Some(2))
45+
val degree = new IntParam(this, "degree", "the polynomial degree to expand")
46+
setDefault(degree -> 2)
4547

4648
/** @group getParam */
47-
def getDegree: Int = get(degree)
49+
def getDegree: Int = getOrDefault(degree)
4850

4951
/** @group setParam */
5052
def setDegree(value: Int): this.type = set(degree, value)
5153

52-
/**
 * Builds the function that is applied to each input vector.
 *
 * The degree is looked up in `paramMap` once, when the function is created, so the returned
 * closure only carries the resolved `Int` and does not repeat the lookup for every vector.
 *
 * @param paramMap parameter map from which the `degree` value is read
 * @return a function that expands a vector with [[PolynomialMapperV2.expand]]
 */
override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
  val d = paramMap(degree)
  v => PolynomialMapperV2.expand(v, d)
}
5559

5660
override protected def outputDataType: DataType = new VectorUDT()
@@ -238,3 +242,112 @@ object PolynomialMapper {
238242
}
239243
}
240244
}
245+
246+
/**
 * Polynomial expansion of a feature vector, done via recursion on the last feature. Given n
 * features and degree d, the size after expansion is (n + d choose d) (including the constant 1
 * and the first-order values). For example, let f([a, b, c], 3) be the function that expands
 * [a, b, c] to all of their monomials of degree at most 3. We have the following recursion:
 *
 * {{{
 * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3]
 * }}}
 *
 * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
 * current index and increment it properly for sparse input.
 */
object PolynomialMapperV2 {

  /**
   * Computes the binomial coefficient "n choose k".
   *
   * The value is accumulated in a `Long` and range-checked after every step, so a result that
   * does not fit in an `Int` is reported instead of silently wrapping around.
   *
   * @param n size of the set to choose from
   * @param k number of elements to choose; any value below 1 yields 1
   * @return the binomial coefficient as an `Int`
   * @throws IllegalArgumentException if the result exceeds `Int.MaxValue`
   */
  private def choose(n: Int, k: Int): Int = {
    var result = 1L
    var i = 1
    while (i <= k) {
      // After step i, `result` equals C(n - k + i, i), so the integer division is exact. The
      // check below keeps `result` within Int range, hence the Long product cannot overflow.
      result = result * (n - k + i) / i
      require(result <= Int.MaxValue,
        s"The binomial coefficient C($n, $k) does not fit in an Int; " +
          "reduce the number of features or the polynomial degree.")
      i += 1
    }
    result.toInt
  }

  /**
   * Number of monomials of total degree at most `degree` over `numFeatures` features, including
   * the constant term.
   */
  private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)

  /**
   * Recursively writes the expansion of features `0..lastIdx` into `polyValues`.
   *
   * @param values feature values of the dense input
   * @param lastIdx index of the last feature still to be expanded; negative when none is left
   * @param degree maximum total degree still available to the remaining features
   * @param multiplier product of the powers of the features already consumed by the callers
   * @param polyValues output array, assumed to be zero-initialized by the caller
   * @param curPolyIdx position in `polyValues` where this sub-expansion starts
   * @return the position right after this sub-expansion, i.e. `curPolyIdx` plus the expansion
   *         size for `lastIdx + 1` features and `degree`
   */
  private def expandDense(
      values: Array[Double],
      lastIdx: Int,
      degree: Int,
      multiplier: Double,
      polyValues: Array[Double],
      curPolyIdx: Int): Int = {
    if (multiplier == 0.0) {
      // Every monomial in this sub-expansion is zero; nothing is written.
    } else if (degree == 0 || lastIdx < 0) {
      // Only the constant term of this sub-expansion remains.
      polyValues(curPolyIdx) = multiplier
    } else {
      val v = values(lastIdx)
      val lastIdx1 = lastIdx - 1
      var alpha = multiplier
      var i = 0
      var curStart = curPolyIdx
      // Iteration i handles the monomials containing v^i; the loop stops early once the
      // accumulated factor becomes zero.
      while (i <= degree && alpha != 0.0) {
        curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart)
        i += 1
        alpha *= v
      }
    }
    curPolyIdx + getPolySize(lastIdx + 1, degree)
  }

  /**
   * Recursively appends the non-zero entries of the expansion of a sparse input to the builders.
   *
   * @param indices feature indices of the stored entries of the sparse input
   * @param values values of the stored entries, aligned with `indices`
   * @param lastIdx position in `indices`/`values` of the last stored entry still to be expanded;
   *                negative when none is left
   * @param lastFeatureIdx feature index of the last feature covered by this sub-expansion, used
   *                       to compute how many output positions the sub-expansion occupies
   * @param degree maximum total degree still available to the remaining features
   * @param multiplier product of the powers of the features already consumed by the callers
   * @param polyIndices builder receiving the output indices
   * @param polyValues builder receiving the output values
   * @param curPolyIdx output position where this sub-expansion starts
   * @return the output position right after this sub-expansion, i.e. `curPolyIdx` plus the
   *         expansion size for `lastFeatureIdx + 1` features and `degree`
   */
  private def expandSparse(
      indices: Array[Int],
      values: Array[Double],
      lastIdx: Int,
      lastFeatureIdx: Int,
      degree: Int,
      multiplier: Double,
      polyIndices: mutable.ArrayBuilder[Int],
      polyValues: mutable.ArrayBuilder[Double],
      curPolyIdx: Int): Int = {
    if (multiplier == 0.0) {
      // Every monomial in this sub-expansion is zero; nothing is emitted.
    } else if (degree == 0 || lastIdx < 0) {
      // Only the constant term of this sub-expansion remains.
      polyIndices += curPolyIdx
      polyValues += multiplier
    } else {
      // Skip all zeros at the tail: the recursion continues with the features located before
      // the last stored entry.
      val v = values(lastIdx)
      val lastIdx1 = lastIdx - 1
      val lastFeatureIdx1 = indices(lastIdx) - 1
      var alpha = multiplier
      var curStart = curPolyIdx
      var i = 0
      // Iteration i handles the monomials containing v^i; the loop stops early once the
      // accumulated factor becomes zero.
      while (i <= degree && alpha != 0.0) {
        curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha,
          polyIndices, polyValues, curStart)
        i += 1
        alpha *= v
      }
    }
    curPolyIdx + getPolySize(lastFeatureIdx + 1, degree)
  }

  /** Expands a dense vector into a dense vector holding every monomial up to `degree`. */
  private def expand(dv: DenseVector, degree: Int): DenseVector = {
    val n = dv.size
    val polySize = getPolySize(n, degree)
    val polyValues = new Array[Double](polySize)
    expandDense(dv.values, n - 1, degree, 1.0, polyValues, 0)
    new DenseVector(polyValues)
  }

  /** Expands a sparse vector, emitting only the entries reached through its stored values. */
  private def expand(sv: SparseVector, degree: Int): SparseVector = {
    val polySize = getPolySize(sv.size, degree)
    val nnz = sv.values.length
    // Size of the expansion of the stored entries alone; used only as a capacity hint.
    val nnzPolySize = getPolySize(nnz, degree)
    val polyIndices = mutable.ArrayBuilder.make[Int]
    polyIndices.sizeHint(nnzPolySize)
    val polyValues = mutable.ArrayBuilder.make[Double]
    polyValues.sizeHint(nnzPolySize)
    expandSparse(
      sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, 0)
    new SparseVector(polySize, polyIndices.result(), polyValues.result())
  }

  /**
   * Expands a vector into all monomials of its features up to the given degree.
   *
   * @param v the vector to expand; dense and sparse vectors are supported
   * @param degree the maximum total degree of the generated monomials, must be non-negative
   * @return the expanded vector, dense for dense input and sparse for sparse input
   * @throws IllegalArgumentException if `degree` is negative, if the expanded size does not fit
   *                                  in an `Int`, or if the vector type is not supported
   */
  def expand(v: Vector, degree: Int): Vector = {
    require(degree >= 0, s"The polynomial degree must be non-negative but got $degree.")
    v match {
      case dv: DenseVector => expand(dv, degree)
      case sv: SparseVector => expand(sv, degree)
      case other =>
        val typeName = if (other == null) "null" else other.getClass.getName
        throw new IllegalArgumentException(s"Unsupported vector type: $typeName.")
    }
  }
}

0 commit comments

Comments
 (0)