Skip to content

Commit 5e492e9

Browse files
ericl authored and mengxr committed
[SPARK-12346][ML] Missing attribute names in GLM for vector-type features
Currently, `summary()` fails on a GLM model fitted over a vector feature that is missing ML attributes, because the output feature attributes will also have no names. We can avoid this by having `VectorAssembler` generate names of the form `<column>_<index>` whenever the input attributes are missing names.

cc mengxr

Author: Eric Liang <[email protected]>

Closes #10323 from ericl/spark-12346.
1 parent 44fcf99 commit 5e492e9

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String)
7070
val group = AttributeGroup.fromStructField(field)
7171
if (group.attributes.isDefined) {
7272
// If attributes are defined, copy them with updated names.
73-
group.attributes.get.map { attr =>
73+
group.attributes.get.zipWithIndex.map { case (attr, i) =>
7474
if (attr.name.isDefined) {
7575
// TODO: Define a rigorous naming scheme.
7676
attr.withName(c + "_" + attr.name.get)
7777
} else {
78-
attr
78+
attr.withName(c + "_" + i)
7979
}
8080
}
8181
} else {
8282
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
8383
// from metadata, check the first row.
8484
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
85-
Array.fill(numAttrs)(NumericAttribute.defaultAttr)
85+
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
8686
}
8787
case otherType =>
8888
throw new SparkException(s"VectorAssembler does not support the $otherType type")

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
143143
assert(attrs === expectedAttrs)
144144
}
145145

146+
test("vector attribute generation") {
147+
val formula = new RFormula().setFormula("id ~ vec")
148+
val original = sqlContext.createDataFrame(
149+
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
150+
).toDF("id", "vec")
151+
val model = formula.fit(original)
152+
val result = model.transform(original)
153+
val attrs = AttributeGroup.fromStructField(result.schema("features"))
154+
val expectedAttrs = new AttributeGroup(
155+
"features",
156+
Array[Attribute](
157+
new NumericAttribute(Some("vec_0"), Some(1)),
158+
new NumericAttribute(Some("vec_1"), Some(2))))
159+
assert(attrs === expectedAttrs)
160+
}
161+
162+
test("vector attribute generation with unnamed input attrs") {
163+
val formula = new RFormula().setFormula("id ~ vec2")
164+
val base = sqlContext.createDataFrame(
165+
Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
166+
).toDF("id", "vec")
167+
val metadata = new AttributeGroup(
168+
"vec2",
169+
Array[Attribute](
170+
NumericAttribute.defaultAttr,
171+
NumericAttribute.defaultAttr)).toMetadata
172+
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
173+
val model = formula.fit(original)
174+
val result = model.transform(original)
175+
val attrs = AttributeGroup.fromStructField(result.schema("features"))
176+
val expectedAttrs = new AttributeGroup(
177+
"features",
178+
Array[Attribute](
179+
new NumericAttribute(Some("vec2_0"), Some(1)),
180+
new NumericAttribute(Some("vec2_1"), Some(2))))
181+
assert(attrs === expectedAttrs)
182+
}
183+
146184
test("numeric interaction") {
147185
val formula = new RFormula().setFormula("a ~ b:c:d")
148186
val original = sqlContext.createDataFrame(

mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -111,8 +111,8 @@ class VectorAssemblerSuite
111111
assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
112112
val userSalaryOut = features.getAttr(4)
113113
assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
114-
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
115-
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
114+
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
115+
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
116116
}
117117

118118
test("read/write") {

0 commit comments

Comments
 (0)