Commit 960d2d0
[SPARK-10537] [ML] Document LIBSVM source options in public API doc, and some minor improvements

We should document the options in the public API doc; otherwise it is hard to find them without reading the code. I tried making `DefaultSource` private and moving the documentation to the package doc, but with no public class left under `source.libsvm`, the Java package doc does not show up in the generated HTML (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc on `DefaultSource` instead.

There are several minor updates in this PR:

1. Evaluate `vectorType == "sparse"` only once.
2. Update `hashCode` and `equals`.
3. Remove inherited doc.
4. Delete the temp dir in `afterAll`.

Lewuathe

Author: Xiangrui Meng <[email protected]>
Closes #8699 from mengxr/SPARK-10537

1 parent b01b262 · commit 960d2d0
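For context, here is a minimal sketch of loading LIBSVM data with the two options this commit documents. It assumes a Spark 1.6-era `SQLContext`, an existing `SparkContext` named `sc` (as in `spark-shell`), and the sample file shipped with Spark:

```scala
import org.apache.spark.sql.SQLContext

// Assumes an existing SparkContext `sc` (e.g. in spark-shell).
val sqlContext = new SQLContext(sc)

// "numFeatures" skips the extra pass that would otherwise infer the
// feature dimension; "vectorType" chooses the representation of the
// `features` column ("sparse" is the default).
val df = sqlContext.read.format("libsvm")
  .option("numFeatures", "780")
  .option("vectorType", "dense")
  .load("data/mllib/sample_libsvm_data.txt")

df.printSchema()  // label: double, features: vector
```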

File tree: 3 files changed (+66 −43 lines)

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala (44 additions, 27 deletions)

```diff
@@ -21,12 +21,12 @@ import com.google.common.base.Objects
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
 import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._
  * @param vectorType The type of vector. It can be 'sparse' or 'dense'
  * @param sqlContext The Spark SQLContext
  */
-private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
+private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
     (@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with Logging with Serializable {
 
@@ -47,52 +47,69 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
   override def buildScan(): RDD[Row] = {
     val sc = sqlContext.sparkContext
     val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
-
+    val sparse = vectorType == "sparse"
     baseRdd.map { pt =>
-      val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+      val features = if (sparse) pt.features.toSparse else pt.features.toDense
       Row(pt.label, features)
     }
   }
 
   override def hashCode(): Int = {
-    Objects.hashCode(path, schema)
+    Objects.hashCode(path, Double.box(numFeatures), vectorType)
   }
 
   override def equals(other: Any): Boolean = other match {
-    case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
-    case _ => false
+    case that: LibSVMRelation =>
+      path == that.path &&
+        numFeatures == that.numFeatures &&
+        vectorType == that.vectorType
+    case _ =>
+      false
   }
-
 }
 
 /**
- * This is used for creating DataFrame from LibSVM format file.
- * The LibSVM file path must be specified to DefaultSource.
+ * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
+ * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
+ * `features` containing feature vectors stored as [[Vector]]s.
+ *
+ * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and
+ * optionally specify options, for example:
+ * {{{
+ *   // Scala
+ *   val df = sqlContext.read.format("libsvm")
+ *     .option("numFeatures", "780")
+ *     .load("data/mllib/sample_libsvm_data.txt")
+ *
+ *   // Java
+ *   DataFrame df = sqlContext.read().format("libsvm")
+ *     .option("numFeatures", "780")
+ *     .load("data/mllib/sample_libsvm_data.txt");
+ * }}}
+ *
+ * LIBSVM data source supports the following options:
+ *  - "numFeatures": number of features.
+ *    If unspecified or nonpositive, the number of features will be determined automatically at the
+ *    cost of one additional pass.
+ *    This is also useful when the dataset is already split into multiple files and you want to
+ *    load them separately, because some features may not be present in certain files, which leads
+ *    to inconsistent feature dimensions.
+ *  - "vectorType": feature vector type, "sparse" (default) or "dense".
+ *
+ * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
  */
 @Since("1.6.0")
 class DefaultSource extends RelationProvider with DataSourceRegister {
 
   @Since("1.6.0")
   override def shortName(): String = "libsvm"
 
-  private def checkPath(parameters: Map[String, String]): String = {
-    require(parameters.contains("path"), "'path' must be specified")
-    parameters.get("path").get
-  }
-
-  /**
-   * Returns a new base relation with the given parameters.
-   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
-   * by the Map that is passed to the function.
-   */
+  @Since("1.6.0")
   override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
     : BaseRelation = {
-    val path = checkPath(parameters)
+    val path = parameters.getOrElse("path",
+      throw new IllegalArgumentException("'path' must be specified"))
     val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
-    /**
-     * featuresType can be selected "dense" or "sparse".
-     * This parameter decides the type of returned feature vector.
-     */
     val vectorType = parameters.getOrElse("vectorType", "sparse")
     new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
   }
```
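The `hashCode`/`equals` update (item 2 of the commit message) keeps the two methods consistent: `equals` now compares the three constructor arguments directly, so `hashCode` must be derived from those same fields rather than from the schema. A standalone sketch of the pattern, using a hypothetical `Relation` class and Guava's `Objects` helper as the Spark code does:

```scala
import com.google.common.base.Objects

// Hypothetical stand-in for LibSVMRelation. Hashing and equality are
// derived from the same fields, so equal instances always share a hash.
class Relation(val path: String, val numFeatures: Int, val vectorType: String) {

  override def hashCode(): Int =
    Objects.hashCode(path, Int.box(numFeatures), vectorType)

  override def equals(other: Any): Boolean = other match {
    case that: Relation =>
      path == that.path &&
        numFeatures == that.numFeatures &&
        vectorType == that.vectorType
    case _ => false
  }
}

// Equal relations hash equally, which a schema-based hashCode could not
// guarantee once numFeatures and vectorType started participating in equality.
val a = new Relation("/data", 780, "sparse")
val b = new Relation("/data", 780, "sparse")
assert(a == b && a.hashCode == b.hashCode)
```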
mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
renamed to mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java (12 additions, 12 deletions)

```diff
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.source;
+package org.apache.spark.ml.source.libsvm;
 
 import java.io.File;
 import java.io.IOException;
@@ -42,34 +42,34 @@
  */
 public class JavaLibSVMRelationSuite {
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
-  private transient DataFrame dataset;
+  private transient SQLContext sqlContext;
 
-  private File tmpDir;
-  private File path;
+  private File tempDir;
+  private String path;
 
   @Before
   public void setUp() throws IOException {
     jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
-    jsql = new SQLContext(jsc);
-
-    tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
-    path = new File(tmpDir.getPath(), "part-00000");
+    sqlContext = new SQLContext(jsc);
 
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
+    File file = new File(tempDir, "part-00000");
     String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
-    Files.write(s, path, Charsets.US_ASCII);
+    Files.write(s, file, Charsets.US_ASCII);
+    path = tempDir.toURI().toString();
   }
 
   @After
   public void tearDown() {
     jsc.stop();
     jsc = null;
-    Utils.deleteRecursively(tmpDir);
+    Utils.deleteRecursively(tempDir);
   }
 
   @Test
   public void verifyLibSVMDF() {
-    dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath());
+    DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+      .load(path);
     Assert.assertEquals("label", dataset.columns()[0]);
     Assert.assertEquals("features", dataset.columns()[1]);
     Row r = dataset.first();
```
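Note the Java test now writes a `part-00000` file into a temp directory and loads the directory URI rather than the file itself; the data source resolves paths through Hadoop's text input, so a directory of part files loads the same way a single file does. A small Scala sketch of the same setup (assuming an existing `sqlContext`; path names are illustrative, and JDK temp-dir helpers stand in for Spark's internal `Utils`):

```scala
import java.io.{File, PrintWriter}
import java.nio.file.Files

// Write one LIBSVM part file into a fresh temp directory.
val tempDir = Files.createTempDirectory("datasource").toFile
val writer = new PrintWriter(new File(tempDir, "part-00000"))
writer.write("1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0")
writer.close()

// Load by directory URI, not by the individual part file.
val df = sqlContext.read.format("libsvm")
  .option("vectorType", "dense")
  .load(tempDir.toURI.toString)

assert(df.count() == 3)
```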
mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
renamed to mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala (10 additions, 4 deletions)

```diff
@@ -15,19 +15,20 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.source
+package org.apache.spark.ml.source.libsvm
 
 import java.io.File
 
 import com.google.common.base.Charsets
 import com.google.common.io.Files
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.Utils
 
 class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
+  var tempDir: File = _
   var path: String = _
 
   override def beforeAll(): Unit = {
@@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
       |0
       |0 2:4.0 4:5.0 6:6.0
     """.stripMargin
-    val tempDir = Utils.createTempDir()
-    val file = new File(tempDir.getPath, "part-00000")
+    tempDir = Utils.createTempDir()
+    val file = new File(tempDir, "part-00000")
     Files.write(lines, file, Charsets.US_ASCII)
     path = tempDir.toURI.toString
   }
 
+  override def afterAll(): Unit = {
+    Utils.deleteRecursively(tempDir)
+    super.afterAll()
+  }
+
   test("select as sparse vector") {
     val df = sqlContext.read.format("libsvm").load(path)
     assert(df.columns(0) == "label")
```
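The added `afterAll` (item 4 of the commit message) follows the standard ScalaTest `BeforeAndAfterAll` pattern: create fixtures in `beforeAll`, remove them in `afterAll`, and defer to `super` so parent traits (here `MLlibTestSparkContext`) can finish their own teardown. A minimal standalone sketch, with a hypothetical suite and plain JDK temp-dir helpers in place of Spark's internal `Utils`:

```scala
import java.io.File
import java.nio.file.Files
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class TempDirSuite extends FunSuite with BeforeAndAfterAll {

  var tempDir: File = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    tempDir = Files.createTempDirectory("libsvm-test").toFile
  }

  override def afterAll(): Unit = {
    // Delete our own fixtures first, then let parent traits tear down.
    tempDir.listFiles().foreach(_.delete())
    tempDir.delete()
    super.afterAll()
  }

  test("temp dir is available while tests run") {
    assert(tempDir.isDirectory)
  }
}
```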