
Commit c5a97fc

Author: Andrew Or (committed)

Merge branch 'master' of github.com:apache/spark into fix-input-metrics-coalesce

2 parents: c31a410 + b9dfdcc

File tree

19 files changed: +590, -76 lines

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java

Lines changed: 32 additions & 2 deletions
@@ -115,15 +115,45 @@ int getVersionNumber() {
   public abstract long totalCount();
 
   /**
-   * Adds 1 to {@code item}.
+   * Increments {@code item}'s count by one.
    */
   public abstract void add(Object item);
 
   /**
-   * Adds {@code count} to {@code item}.
+   * Increments {@code item}'s count by {@code count}.
    */
   public abstract void add(Object item, long count);
 
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addLong(long item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addLong(long item, long count);
+
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addString(String item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addString(String item, long count);
+
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addBinary(byte[] item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addBinary(byte[] item, long count);
+
   /**
    * Returns the estimated frequency of {@code item}.
    */
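
The new typed add variants sit alongside the existing generic add(Object). A minimal usage sketch from Scala (assuming the create(eps, confidence, seed) factory and estimateCount query from the same class; all values are illustrative):

import org.apache.spark.util.sketch.CountMinSketch

object CountMinSketchUsage {
  def main(args: Array[String]): Unit = {
    // relative error eps, confidence, random seed -- illustrative values
    val sketch = CountMinSketch.create(0.001, 0.99, 42)

    sketch.addLong(12345L)                 // increment the count of 12345 by one
    sketch.addString("spark", 3)           // increment the count of "spark" by three
    sketch.addBinary(Array[Byte](1, 2))    // raw bytes are also supported

    println(sketch.estimateCount("spark")) // an over-estimate, at least 3
    println(sketch.totalCount())           // sum of all increments so far
  }
}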

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java

Lines changed: 28 additions & 7 deletions
@@ -25,7 +25,6 @@
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
 import java.io.Serializable;
-import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -146,27 +145,49 @@ public void add(Object item, long count) {
     }
   }
 
-  private void addString(String item, long count) {
+  @Override
+  public void addString(String item) {
+    addString(item, 1);
+  }
+
+  @Override
+  public void addString(String item, long count) {
+    addBinary(Utils.getBytesFromUTF8String(item), count);
+  }
+
+  @Override
+  public void addLong(long item) {
+    addLong(item, 1);
+  }
+
+  @Override
+  public void addLong(long item, long count) {
     if (count < 0) {
       throw new IllegalArgumentException("Negative increments not implemented");
     }
 
-    int[] buckets = getHashBuckets(item, depth, width);
-
     for (int i = 0; i < depth; ++i) {
-      table[i][buckets[i]] += count;
+      table[i][hash(item, i)] += count;
    }
 
     totalCount += count;
   }
 
-  private void addLong(long item, long count) {
+  @Override
+  public void addBinary(byte[] item) {
+    addBinary(item, 1);
+  }
+
+  @Override
+  public void addBinary(byte[] item, long count) {
     if (count < 0) {
       throw new IllegalArgumentException("Negative increments not implemented");
     }
 
+    int[] buckets = getHashBuckets(item, depth, width);
+
     for (int i = 0; i < depth; ++i) {
-      table[i][hash(item, i)] += count;
+      table[i][buckets[i]] += count;
     }
 
     totalCount += count;
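
For readers unfamiliar with the update pattern the implementation follows, here is an illustrative, standalone count-min sketch in Scala (the per-row hash is a hypothetical MurmurHash3-based stand-in, not Spark's hashing scheme): every add bumps one counter per row, and the estimate is the minimum over rows, so it can only over-count.

import scala.util.hashing.MurmurHash3

class TinyCountMin(depth: Int, width: Int) {
  // depth rows of width counters, all starting at zero
  private val table = Array.ofDim[Long](depth, width)
  private var total = 0L

  // hypothetical per-row hash: the row index doubles as the seed
  private def bucket(item: Long, row: Int): Int = {
    val h = MurmurHash3.stringHash(item.toString, row)
    ((h % width) + width) % width
  }

  def add(item: Long, count: Long = 1L): Unit = {
    var row = 0
    while (row < depth) {
      table(row)(bucket(item, row)) += count
      row += 1
    }
    total += count
  }

  // minimum over rows: hash collisions only inflate counts, never deflate them
  def estimate(item: Long): Long =
    (0 until depth).map(row => table(row)(bucket(item, row))).min

  def totalCount: Long = total
}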

docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala

Lines changed: 9 additions & 6 deletions
@@ -21,7 +21,7 @@ import java.sql.Connection
 import java.util.Properties
 
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -39,20 +39,21 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
   override def dataPreparation(conn: Connection): Unit = {
     conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
     conn.setCatalog("foo")
+    conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
     conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
       + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
-      + "c10 integer[], c11 text[], c12 real[])").executeUpdate()
+      + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate()
     conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
       + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
-      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate()
+      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate()
   }
 
   test("Type mapping for various types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass)
-    assert(types.length == 13)
+    assert(types.length == 14)
     assert(classOf[String].isAssignableFrom(types(0)))
     assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
     assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
@@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
     assert(classOf[Seq[String]].isAssignableFrom(types(11)))
     assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
+    assert(classOf[String].isAssignableFrom(types(13)))
     assert(rows(0).getString(0).equals("hello"))
     assert(rows(0).getInt(1) == 42)
     assert(rows(0).getDouble(2) == 1.25)
     assert(rows(0).getLong(3) == 123456789012345L)
-    assert(rows(0).getBoolean(4) == false)
+    assert(!rows(0).getBoolean(4))
     // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's...
     assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
       Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
     assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
       Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
-    assert(rows(0).getBoolean(7) == true)
+    assert(rows(0).getBoolean(7))
     assert(rows(0).getString(8) == "172.16.0.42")
     assert(rows(0).getString(9) == "192.168.0.0/16")
     assert(rows(0).getSeq(10) == Seq(1, 2))
     assert(rows(0).getSeq(11) == Seq("a", null, "b"))
     assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
+    assert(rows(0).getString(13) == "d1")
   }
 
   test("Basic write test") {
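
The new c13 column exercises the JDBC type mapping for PostgreSQL ENUM values, which now come back as plain strings. A hedged read-side sketch, reusing the suite's own jdbcUrl and table:

import java.util.Properties

// jdbcUrl points at the dockerized Postgres instance set up by the suite
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val enumValue = df.select("c13").collect().head.getString(0)
assert(enumValue == "d1")  // the ENUM literal inserted in dataPreparation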
mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model fitted by [[IterativelyReweightedLeastSquares]].
+ * @param coefficients model coefficients
+ * @param intercept model intercept
+ */
+private[ml] class IterativelyReweightedLeastSquaresModel(
+    val coefficients: DenseVector,
+    val intercept: Double) extends Serializable
+
+/**
+ * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve
+ * certain optimization problems by an iterative method. In each step of the iterations, it
+ * involves solving a weighted least squares (WLS) problem by [[WeightedLeastSquares]].
+ * It can be used to find maximum likelihood estimates of a generalized linear model (GLM),
+ * find M-estimator in robust regression and other optimization problems.
+ *
+ * @param initialModel the initial guess model.
+ * @param reweightFunc the reweight function which is used to update offsets and weights
+ *                     at each iteration.
+ * @param fitIntercept whether to fit intercept.
+ * @param regParam L2 regularization parameter used by WLS.
+ * @param maxIter maximum number of iterations.
+ * @param tol the convergence tolerance.
+ *
+ * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares
+ * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives,
+ * Journal of the Royal Statistical Society. Series B, 1984.]]
+ */
+private[ml] class IterativelyReweightedLeastSquares(
+    val initialModel: WeightedLeastSquaresModel,
+    val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double),
+    val fitIntercept: Boolean,
+    val regParam: Double,
+    val maxIter: Int,
+    val tol: Double) extends Logging with Serializable {
+
+  def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {
+
+    var converged = false
+    var iter = 0
+
+    var model: WeightedLeastSquaresModel = initialModel
+    var oldModel: WeightedLeastSquaresModel = null
+
+    while (iter < maxIter && !converged) {
+
+      oldModel = model
+
+      // Update offsets and weights using reweightFunc
+      val newInstances = instances.map { instance =>
+        val (newOffset, newWeight) = reweightFunc(instance, oldModel)
+        Instance(newOffset, newWeight, instance.features)
+      }
+
+      // Estimate new model
+      model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false,
+        standardizeLabel = false).fit(newInstances)
+
+      // Check convergence
+      val oldCoefficients = oldModel.coefficients
+      val coefficients = model.coefficients
+      BLAS.axpy(-1.0, coefficients, oldCoefficients)
+      val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
+        math.max(math.abs(x), math.abs(y))
+      }
+      val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))
+
+      if (maxTol < tol) {
+        converged = true
+        logInfo(s"IRLS converged in $iter iterations.")
+      }
+
+      logInfo(s"Iteration $iter : relative tolerance = $maxTol")
+      iter = iter + 1
+
+      if (iter == maxIter) {
+        logInfo(s"IRLS reached the max number of iterations: $maxIter.")
+      }
+
+    }
+
+    new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept)
+  }
+}
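
To make the reweightFunc contract concrete, here is an illustrative reweight function (not part of the patch) for Poisson regression with a log link, using the standard IRLS working response and weight. It relies on the predict method added to WeightedLeastSquaresModel in this same commit (next file), and since these classes are private[ml] it would have to live inside the ml package:

import org.apache.spark.ml.feature.Instance

// Poisson / log link: working label z = eta + (y - mu) / mu, working weight w * mu
val poissonReweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
  (instance, model) =>
    val eta = model.predict(instance.features)   // current linear predictor
    val mu = math.exp(eta)                        // inverse of the log link
    val workingLabel = eta + (instance.label - mu) / mu
    val workingWeight = instance.weight * mu
    (workingLabel, workingWeight)
}

A call such as new IterativelyReweightedLeastSquares(initialModel, poissonReweightFunc, fitIntercept = true, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances) would then repeat WLS solves until the largest coefficient change falls below tol.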

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 6 additions & 1 deletion
@@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD
 private[ml] class WeightedLeastSquaresModel(
     val coefficients: DenseVector,
     val intercept: Double,
-    val diagInvAtWA: DenseVector) extends Serializable
+    val diagInvAtWA: DenseVector) extends Serializable {
+
+  def predict(features: Vector): Double = {
+    BLAS.dot(coefficients, features) + intercept
+  }
+}
 
 /**
  * Weighted least squares solver via normal equation.
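
The new predict is just a dense dot product plus the intercept; a tiny worked check with made-up numbers (not from the patch):

import org.apache.spark.mllib.linalg.Vectors

val coefficients = Vectors.dense(0.5, -2.0)   // made-up model coefficients
val intercept = 1.0
val features = Vectors.dense(2.0, 1.0)

// same arithmetic as BLAS.dot(coefficients, features) + intercept
val prediction = coefficients.toArray.zip(features.toArray)
  .map { case (c, x) => c * x }.sum + intercept
// 0.5 * 2.0 + (-2.0) * 1.0 + 1.0 == 0.0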
