Skip to content

Commit 2ddeb63

Browse files
Lewuathe authored and
mengxr committed
[SPARK-10117] [MLLIB] Implement SQL data source API for reading LIBSVM data
It is convenient to implement the data source API for the LIBSVM format to get better integration with DataFrames and the ML pipeline API. Two options are implemented: * `numFeatures`: specifies the dimension of the feature vector * `vectorType`: specifies the type of the output vector (`sparse` or `dense`); `sparse` is the default. Author: lewuathe <[email protected]> Closes #8537 from Lewuathe/SPARK-10117 and squashes the following commits: 986999d [lewuathe] Change unit test phrase 11d513f [lewuathe] Fix some reviews 21600a4 [lewuathe] Merge branch 'master' into SPARK-10117 9ce63c7 [lewuathe] Rewrite service loader file 1fdd2df [lewuathe] Merge branch 'SPARK-10117' of github.com:Lewuathe/spark into SPARK-10117 ba3657c [lewuathe] Merge branch 'master' into SPARK-10117 0ea1c1c [lewuathe] LibSVMRelation is registered into META-INF 4f40891 [lewuathe] Improve test suites 5ab62ab [lewuathe] Merge branch 'master' into SPARK-10117 8660d0e [lewuathe] Fix Java unit test b56a948 [lewuathe] Merge branch 'master' into SPARK-10117 2c12894 [lewuathe] Remove unnecessary tag 7d693c2 [lewuathe] Resolv conflict 62010af [lewuathe] Merge branch 'master' into SPARK-10117 a97ee97 [lewuathe] Fix some points aef9564 [lewuathe] Fix 70ee4dd [lewuathe] Add Java test 3fd8dce [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data 40d3027 [lewuathe] Add Java test 7056d4a [lewuathe] Merge branch 'master' into SPARK-10117 99accaa [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data
1 parent c1bc4f4 commit 2ddeb63

File tree

4 files changed

+256
-0
lines changed

4 files changed

+256
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
org.apache.spark.ml.source.libsvm.DefaultSource
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.source.libsvm
19+
20+
import com.google.common.base.Objects
21+
22+
import org.apache.spark.Logging
23+
import org.apache.spark.annotation.Since
24+
import org.apache.spark.mllib.linalg.VectorUDT
25+
import org.apache.spark.mllib.util.MLUtils
26+
import org.apache.spark.rdd.RDD
27+
import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
28+
import org.apache.spark.sql.{Row, SQLContext}
29+
import org.apache.spark.sql.sources._
30+
31+
/**
 * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
 *
 * The relation exposes two non-nullable columns: `label` (Double) and `features` (vector).
 *
 * @param path File path of LibSVM format
 * @param numFeatures The number of features; forwarded unchanged to `MLUtils.loadLibSVMFile`
 * @param vectorType The type of vector. `"dense"` yields dense vectors; any other value
 *                   yields sparse vectors
 * @param sqlContext The Spark SQLContext
 */
private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
    (@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with Logging with Serializable {

  /** Fixed two-column schema: `label: Double` and `features: Vector`, both non-nullable. */
  override def schema: StructType = StructType(
    StructField("label", DoubleType, nullable = false) ::
      StructField("features", new VectorUDT(), nullable = false) :: Nil
  )

  /**
   * Loads the LibSVM file at `path` and converts every labeled point into a
   * `Row(label, features)`, with `features` converted to the requested vector type.
   */
  override def buildScan(): RDD[Row] = {
    val sc = sqlContext.sparkContext
    val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)

    // Evaluate the option once into a local value so that the closure below captures
    // only a Boolean instead of a reference to this relation.
    val dense = vectorType == "dense"
    baseRdd.map { pt =>
      val features = if (dense) pt.features.toDense else pt.features.toSparse
      Row(pt.label, features)
    }
  }

  // The schema is the same for every instance, so identity is defined by the three
  // constructor options. Leaving out numFeatures or vectorType would make relations that
  // produce different data compare equal.
  override def hashCode(): Int = {
    Objects.hashCode(path, Int.box(numFeatures), vectorType)
  }

  override def equals(other: Any): Boolean = other match {
    case that: LibSVMRelation =>
      path == that.path && numFeatures == that.numFeatures && vectorType == that.vectorType
    case _ => false
  }

}
67+
68+
/**
 * This is used for creating DataFrame from LibSVM format file.
 * The LibSVM file path must be specified to DefaultSource.
 *
 * Supported options:
 *  - `path` (required): location of the LibSVM data.
 *  - `numFeatures` (optional, default `-1`): number of features, passed through to the loader.
 *  - `vectorType` (optional, default `sparse`): `dense` or `sparse`.
 */
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {

  /** Short name under which this source is selected, e.g. `format("libsvm")`. */
  @Since("1.6.0")
  override def shortName(): String = "libsvm"

  /**
   * Returns the value of the mandatory `path` option.
   * Throws an IllegalArgumentException (via `require`) when the option is absent.
   */
  private def checkPath(parameters: Map[String, String]): String = {
    require(parameters.contains("path"), "'path' must be specified")
    // Presence is guaranteed by the check above, so a direct lookup is safe here.
    parameters("path")
  }

  /**
   * Returns a new base relation with the given parameters.
   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
   * by the Map that is passed to the function.
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
    : BaseRelation = {
    val path = checkPath(parameters)
    // "-1" is the default handed to the loader when the caller gives no dimension.
    val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
    // vectorType can be selected as "dense" or "sparse".
    // This parameter decides the type of the returned feature vector.
    val vectorType = parameters.getOrElse("vectorType", "sparse")
    new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
  }
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.source;
19+
20+
import java.io.File;
21+
import java.io.IOException;
22+
23+
import com.google.common.base.Charsets;
24+
import com.google.common.io.Files;
25+
26+
import org.junit.After;
27+
import org.junit.Assert;
28+
import org.junit.Before;
29+
import org.junit.Test;
30+
31+
import org.apache.spark.api.java.JavaSparkContext;
32+
import org.apache.spark.mllib.linalg.DenseVector;
33+
import org.apache.spark.mllib.linalg.Vectors;
34+
import org.apache.spark.sql.DataFrame;
35+
import org.apache.spark.sql.Row;
36+
import org.apache.spark.sql.SQLContext;
37+
import org.apache.spark.util.Utils;
38+
39+
40+
/**
41+
* Test LibSVMRelation in Java.
42+
*/
43+
public class JavaLibSVMRelationSuite {
44+
private transient JavaSparkContext jsc;
45+
private transient SQLContext jsql;
46+
private transient DataFrame dataset;
47+
48+
private File tmpDir;
49+
private File path;
50+
51+
@Before
52+
public void setUp() throws IOException {
53+
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
54+
jsql = new SQLContext(jsc);
55+
56+
tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
57+
path = new File(tmpDir.getPath(), "part-00000");
58+
59+
String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
60+
Files.write(s, path, Charsets.US_ASCII);
61+
}
62+
63+
@After
64+
public void tearDown() {
65+
jsc.stop();
66+
jsc = null;
67+
Utils.deleteRecursively(tmpDir);
68+
}
69+
70+
@Test
71+
public void verifyLibSVMDF() {
72+
dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath());
73+
Assert.assertEquals("label", dataset.columns()[0]);
74+
Assert.assertEquals("features", dataset.columns()[1]);
75+
Row r = dataset.first();
76+
Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
77+
DenseVector v = r.getAs(1);
78+
Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
79+
}
80+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.source
19+
20+
import java.io.File
21+
22+
import com.google.common.base.Charsets
23+
import com.google.common.io.Files
24+
25+
import org.apache.spark.SparkFunSuite
26+
import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
27+
import org.apache.spark.mllib.util.MLlibTestSparkContext
28+
import org.apache.spark.util.Utils
29+
30+
/**
 * Exercises the "libsvm" data source through `sqlContext.read`, covering the default sparse
 * output, the dense output, and an explicitly specified feature dimension.
 */
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
  // URI of the directory holding the fixture file; assigned in beforeAll.
  var path: String = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Three records; the middle one has a label and no features.
    val lines =
      """
        |1 1:1.0 3:2.0 5:3.0
        |0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
    val tempDir = Utils.createTempDir()
    Files.write(lines, new File(tempDir.getPath, "part-00000"), Charsets.US_ASCII)
    path = tempDir.toURI.toString
  }

  test("select as sparse vector") {
    val frame = sqlContext.read.format("libsvm").load(path)
    assert(frame.columns(0) == "label")
    assert(frame.columns(1) == "features")

    val firstRow = frame.first()
    assert(firstRow.getDouble(0) == 1.0)

    val features = firstRow.getAs[SparseVector](1)
    assert(features == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }

  test("select as dense vector") {
    val frame = sqlContext.read.format("libsvm").option("vectorType", "dense").load(path)
    assert(frame.columns(0) == "label")
    assert(frame.columns(1) == "features")
    assert(frame.count() == 3)

    val firstRow = frame.first()
    assert(firstRow.getDouble(0) == 1.0)

    val features = firstRow.getAs[DenseVector](1)
    assert(features == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
  }

  test("select a vector with specifying the longer dimension") {
    val frame = sqlContext.read.option("numFeatures", "100").format("libsvm").load(path)

    val features = frame.first().getAs[SparseVector](1)
    assert(features == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
  }
}

0 commit comments

Comments
 (0)