Skip to content

Commit 9af2e95

Browse files
committed
refine the code style
1 parent ad6c82d commit 9af2e95

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ import org.apache.spark.mllib.util.MLUtils._
2323
import org.apache.spark.rdd.RDD
2424

2525
/**
26-
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an implicit conversion.
27-
* Import `org.apache.spark.MLContext._` at the top of your program to use these functions.
26+
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
27+
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
28+
* these functions.
2829
*/
2930
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
3031

@@ -81,31 +82,36 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8182
/**
8283
* Compute the norm-2 of each column in the RDD with `size` as the dimension of each `Vector`.
8384
*/
84-
def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze).aggregate(BV.zeros[Double](size))(
85-
seqOp = (c, v) => c + (v :* v),
86-
combOp = (lhs, rhs) => lhs + rhs
87-
).map(math.sqrt))
85+
def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze)
86+
.aggregate(BV.zeros[Double](size))(
87+
seqOp = (c, v) => c + (v :* v),
88+
combOp = (lhs, rhs) => lhs + rhs
89+
).map(math.sqrt)
90+
)
8891

8992
/**
9093
* Compute the standard deviation of each column in the RDD.
9194
*/
9295
def colSDs(): Vector = colSDs(self.take(1).head.size)
9396

9497
/**
95-
* Compute the standard deviation of each column in the RDD with `size` as the dimension of each `Vector`.
98+
* Compute the standard deviation of each column in the RDD with `size` as the dimension of each
99+
* `Vector`.
96100
*/
97101
def colSDs(size: Int): Vector = {
98102
val means = self.colMeans()
99-
Vectors.fromBreeze(self.map(x => x.toBreeze - means.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
100-
seqOp = (c, v) => (c, v) match {
101-
case ((prev, cnt), current) =>
102-
(((prev :* cnt) + (current :* current)) :/ (cnt + 1.0), cnt + 1.0)
103-
},
104-
combOp = (lhs, rhs) => (lhs, rhs) match {
105-
case ((lhsVec, lhsCnt), (rhsVec, rhsCnt)) =>
106-
((lhsVec :* lhsCnt) + (rhsVec :* rhsCnt) :/ (lhsCnt + rhsCnt), lhsCnt + rhsCnt)
107-
}
108-
)._1.map(math.sqrt))
103+
Vectors.fromBreeze(self.map(x => x.toBreeze - means.toBreeze)
104+
.aggregate((BV.zeros[Double](size), 0.0))(
105+
seqOp = (c, v) => (c, v) match {
106+
case ((prev, cnt), current) =>
107+
(((prev :* cnt) + (current :* current)) :/ (cnt + 1.0), cnt + 1.0)
108+
},
109+
combOp = (lhs, rhs) => (lhs, rhs) match {
110+
case ((lhsVec, lhsCnt), (rhsVec, rhsCnt)) =>
111+
((lhsVec :* lhsCnt) + (rhsVec :* rhsCnt) :/ (lhsCnt + rhsCnt), lhsCnt + rhsCnt)
112+
}
113+
)._1.map(math.sqrt)
114+
)
109115
}
110116

111117
/**
@@ -119,12 +125,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
119125
}
120126

121127
/**
122-
* Find the optional max vector in the RDD, `None` will be returned if there is no elements at all.
128+
* Find the optional max vector in the RDD, `None` will be returned if there is no elements at
129+
* all.
123130
*/
124131
def maxOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(cmp)
125132

126133
/**
127-
* Find the optional min vector in the RDD, `None` will be returned if there is no elements at all.
134+
* Find the optional min vector in the RDD, `None` will be returned if there is no elements at
135+
* all.
128136
*/
129137
def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _))
130138

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,38 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
6161

6262
test("rowMeans") {
6363
val data = sc.parallelize(localData, 2)
64-
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)), "Row means do not match.")
64+
assert(equivVector(Vectors.dense(data.rowMeans().collect()), Vectors.dense(rowMeans)),
65+
"Row means do not match.")
6566
}
6667

6768
test("rowNorm2") {
6869
val data = sc.parallelize(localData, 2)
69-
assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)), "Row norm2s do not match.")
70+
assert(equivVector(Vectors.dense(data.rowNorm2().collect()), Vectors.dense(rowNorm2)),
71+
"Row norm2s do not match.")
7072
}
7173

7274
test("rowSDs") {
7375
val data = sc.parallelize(localData, 2)
74-
assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)), "Row SDs do not match.")
76+
assert(equivVector(Vectors.dense(data.rowSDs().collect()), Vectors.dense(rowSDs)),
77+
"Row SDs do not match.")
7578
}
7679

7780
test("colMeans") {
7881
val data = sc.parallelize(localData, 2)
79-
assert(equivVector(data.colMeans(), Vectors.dense(colMeans)), "Column means do not match.")
82+
assert(equivVector(data.colMeans(), Vectors.dense(colMeans)),
83+
"Column means do not match.")
8084
}
8185

8286
test("colNorm2") {
8387
val data = sc.parallelize(localData, 2)
84-
assert(equivVector(data.colNorm2(), Vectors.dense(colNorm2)), "Column norm2s do not match.")
88+
assert(equivVector(data.colNorm2(), Vectors.dense(colNorm2)),
89+
"Column norm2s do not match.")
8590
}
8691

8792
test("colSDs") {
8893
val data = sc.parallelize(localData, 2)
89-
assert(equivVector(data.colSDs(), Vectors.dense(colSDs)), "Column SDs do not match.")
94+
assert(equivVector(data.colSDs(), Vectors.dense(colSDs)),
95+
"Column SDs do not match.")
9096
}
9197

9298
test("maxOption") {

0 commit comments

Comments
 (0)