Commit e7e3bfd

mengxr authored and jkbradley committed
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling is different from the design doc: we didn't save default and user-set params separately, so when we load an instance back, all parameters are user-set. This does cause issues, but saving them separately would also cause issues whenever we modify the default params.

TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #9454 from mengxr/SPARK-11217.

(cherry picked from commit c447c9d)
Signed-off-by: Joseph K. Bradley <[email protected]>
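As a concrete example, here is what the round trip looks like for the Binarizer updated in this PR (a sketch, assuming an active SparkContext and a writable `/tmp/binarizer` path):

~~~scala
import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

// Save using the current active context; overwrite() avoids the IOException
// thrown when the output path already exists.
binarizer.write.overwrite().save("/tmp/binarizer")

// Load it back. Since defaults are not saved separately, every param
// (inputCol, outputCol, threshold) comes back as user-set.
val loaded = Binarizer.load("/tmp/binarizer")
assert(loaded.getThreshold == 0.5)
~~~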
1 parent 52e921c commit e7e3bfd

File tree

7 files changed, +469 −4 lines changed


mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala

Lines changed: 9 additions & 2 deletions

~~~diff
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
  */
 @Experimental
 final class Binarizer(override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol {
+  extends Transformer with Writable with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("binarizer"))
 
@@ -86,4 +86,11 @@ final class Binarizer(override val uid: String)
   }
 
   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object Binarizer extends Readable[Binarizer] {
+
+  override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
 }
~~~

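The three added lines above are the entire per-class opt-in cost, and this is the pattern the follow-up PR would repeat across other transformers. As a sketch, here is a hypothetical no-op transformer following it (NoOpTransformer is illustrative, not part of this commit; DefaultParamsWriter and DefaultParamsReader are private[ml], so such a class must live under the org.apache.spark.ml package):

~~~scala
package org.apache.spark.ml.example  // hypothetical; must sit under org.apache.spark.ml

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical no-op transformer showing the default save/load boilerplate.
class NoOpTransformer(override val uid: String) extends Transformer with Writable {

  def this() = this(Identifiable.randomUID("noop"))

  override def transform(dataset: DataFrame): DataFrame = dataset

  override def transformSchema(schema: StructType): StructType = schema

  override def copy(extra: ParamMap): NoOpTransformer = defaultCopy(extra)

  // Opt in to the default JSON-based save.
  override def write: Writer = new DefaultParamsWriter(this)
}

object NoOpTransformer extends Readable[NoOpTransformer] {

  // Opt in to the default JSON-based load.
  override def read: Reader[NoOpTransformer] = new DefaultParamsReader[NoOpTransformer]
}
~~~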
mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 1 deletion

~~~diff
@@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable {
   /**
    * Sets a parameter in the embedded param map.
    */
-  protected final def set[T](param: Param[T], value: T): this.type = {
+  final def set[T](param: Param[T], value: T): this.type = {
     set(param -> value)
   }
~~~
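This visibility change, from `protected final` to `final`, exists so that DefaultParamsReader (below) can set decoded param values on a reflectively constructed instance from outside the Params class hierarchy. A minimal sketch of the call site it enables, mirroring DefaultParamsReader.load (class name and uid are illustrative):

~~~scala
import org.apache.spark.ml.param.Params
import org.apache.spark.util.Utils

val cls = Utils.classForName("org.apache.spark.ml.feature.Binarizer")
val instance = cls.getConstructor(classOf[String])
  .newInstance("binarizer_1234").asInstanceOf[Params]
val param = instance.getParam("threshold")
// The reason `set` is now public: this call comes from org.apache.spark.ml.util,
// not from inside a Params subclass.
instance.set(param, param.jsonDecode("0.1"))
~~~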
mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 220 additions & 0 deletions (new file)

~~~scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.IOException

import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils

/**
 * Trait for [[Writer]] and [[Reader]].
 */
private[util] sealed trait BaseReadWrite {
  private var optionSQLContext: Option[SQLContext] = None

  /**
   * Sets the SQL context to use for saving/loading.
   */
  @Since("1.6.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSQLContext = Option(sqlContext)
    this
  }

  /**
   * Returns the user-specified SQL context or the default.
   */
  protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
    SQLContext.getOrCreate(SparkContext.getOrCreate())
  }
}

/**
 * Abstract class for utility classes that can save ML instances.
 */
@Experimental
@Since("1.6.0")
abstract class Writer extends BaseReadWrite {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instance to the input path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   */
  @Since("1.6.0")
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}

/**
 * Trait for classes that provide [[Writer]].
 */
@Since("1.6.0")
trait Writable {

  /**
   * Returns a [[Writer]] instance for this ML instance.
   */
  @Since("1.6.0")
  def write: Writer

  /**
   * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
   */
  @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}

/**
 * Abstract class for utility classes that can load ML instances.
 * @tparam T ML instance type
 */
@Experimental
@Since("1.6.0")
abstract class Reader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  @Since("1.6.0")
  def load(path: String): T

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}

/**
 * Trait for objects that provide [[Reader]].
 * @tparam T ML instance type
 */
@Experimental
@Since("1.6.0")
trait Readable[T] {

  /**
   * Returns a [[Reader]] instance for this class.
   */
  @Since("1.6.0")
  def read: Reader[T]

  /**
   * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
}

/**
 * Default [[Writer]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {

  /**
   * Saves the ML component to the input path.
   */
  override def save(path: String): Unit = {
    val sc = sqlContext.sparkContext

    val hadoopConf = sc.hadoopConfiguration
    val fs = FileSystem.get(hadoopConf)
    val p = new Path(path)
    if (fs.exists(p)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(p, true)
      } else {
        throw new IOException(
          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
      }
    }

    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList
    val metadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

/**
 * Default [[Reader]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 * @tparam T ML instance type
 */
private[ml] class DefaultParamsReader[T] extends Reader[T] {

  /**
   * Loads the ML component from the input path.
   */
  override def load(path: String): T = {
    implicit val format = DefaultFormats
    val sc = sqlContext.sparkContext
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)
    val cls = Utils.classForName((metadata \ "class").extract[String])
    val uid = (metadata \ "uid").extract[String]
    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
    (metadata \ "paramMap") match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
    }
    instance.asInstanceOf[T]
  }
}
~~~
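On disk, DefaultParamsWriter produces a single-partition text file at `<path>/metadata` whose one line is the compact JSON object built above (class, timestamp, uid, paramMap). A sketch of inspecting it directly, assuming `sc` is a live SparkContext and `/tmp/binarizer` was saved as in the earlier example:

~~~scala
import org.json4s._
import org.json4s.jackson.JsonMethods._

implicit val format = DefaultFormats

// Read back the single JSON line written by DefaultParamsWriter.
val metadataJson = sc.textFile("/tmp/binarizer/metadata", 1).first()
val metadata = parse(metadataJson)
val className = (metadata \ "class").extract[String] // "org.apache.spark.ml.feature.Binarizer"
val uid = (metadata \ "uid").extract[String]
val threshold = (metadata \ "paramMap" \ "threshold").extract[Double]
~~~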
mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java

Lines changed: 74 additions & 0 deletions (new file)

~~~java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util;

import java.io.File;
import java.io.IOException;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;

public class JavaDefaultReadWriteSuite {

  JavaSparkContext jsc = null;
  File tempDir = null;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
    tempDir = Utils.createTempDir(
      System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
  }

  @After
  public void tearDown() {
    if (jsc != null) {
      jsc.stop();
      jsc = null;
    }
    Utils.deleteRecursively(tempDir);
  }

  @Test
  public void testDefaultReadWrite() throws IOException {
    String uid = "my_params";
    MyParams instance = new MyParams(uid);
    instance.set(instance.intParam(), 2);
    String outputPath = new File(tempDir, uid).getPath();
    instance.save(outputPath);
    try {
      instance.save(outputPath);
      Assert.fail(
        "Write without overwrite enabled should fail if the output directory already exists.");
    } catch (IOException e) {
      // expected
    }
    SQLContext sqlContext = new SQLContext(jsc);
    instance.write().context(sqlContext).overwrite().save(outputPath);
    MyParams newInstance = MyParams.load(outputPath);
    Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
    Assert.assertEquals("Params should be preserved.",
      2, newInstance.getOrDefault(newInstance.intParam()));
  }
}
~~~
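MyParams here is a small test-only Params implementation with a single int param; it is presumably defined in the Scala test helper DefaultReadWriteTest.scala, one of the seven changed files not shown in this excerpt.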

mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala

Lines changed: 10 additions & 1 deletion

~~~diff
@@ -19,10 +19,11 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var data: Array[Double] = _
 
@@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(x === y, "The feature value is not correct after binarization.")
     }
   }
+
+  test("read/write") {
+    val binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(0.1)
+    testDefaultReadWrite(binarizer)
+  }
 }
~~~
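testDefaultReadWrite comes from the DefaultReadWriteTest helper trait mixed in above, whose definition is not shown in this excerpt. Judging from the Java suite, it presumably performs roughly the following check (a sketch meant to live inside a ScalaTest suite, where `intercept` and `===` are available; the helper's exact body is an assumption):

~~~scala
import java.io.{File, IOException}

def testDefaultReadWriteSketch(binarizer: Binarizer, tempDir: File): Unit = {
  val path = new File(tempDir, binarizer.uid).getPath
  binarizer.save(path)
  // Saving again without overwrite() must fail.
  intercept[IOException] { binarizer.save(path) }
  binarizer.write.overwrite().save(path)
  val loaded = Binarizer.load(path)
  assert(loaded.uid === binarizer.uid)
  assert(loaded.getThreshold === binarizer.getThreshold)
}
~~~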
