@@ -17,8 +17,13 @@
 
 package org.apache.spark.mllib.util
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
  * :: DeveloperApi ::
@@ -46,11 +51,7 @@ trait Exportable { |
 
 }
 
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object Exportable {
+private[mllib] object Exportable {
 
   /** Current version of model import/export format. */
   val latestVersion: String = "1.0"
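
For context, a minimal sketch of the save side of this format: human-readable JSON metadata (carrying Exportable.latestVersion) under path/metadata/ and Parquet-formatted data under path/data/, as the removed LocalExportable doc below describes. This is not part of the patch; the model fields and the Spark 1.3-era SQL calls (createDataFrame, toDF, toJSON, saveAsParquetFile) are assumptions for illustration.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Hypothetical model saver; only the saved fields matter for this sketch.
object MyModelSaver {

  // Assumed schema of the Parquet-formatted model data.
  private case class Data(weights: Seq[Double], intercept: Double)

  def save(sc: SparkContext, path: String, weights: Seq[Double], intercept: Double): Unit = {
    val sqlContext = new SQLContext(sc)
    // Human-readable (JSON) metadata: class name plus the export format version,
    // so loaders can dispatch on "version" as the format evolves.
    sqlContext.createDataFrame(Seq((this.getClass.getName, Exportable.latestVersion)))
      .toDF("class", "version")
      .toJSON.saveAsTextFile(path + "/metadata")
    // Parquet-formatted model data.
    sqlContext.createDataFrame(Seq(Data(weights, intercept)))
      .saveAsParquetFile(path + "/data")
  }
}
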
@@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] { |
 
 }
 
-/*
-/**
- * :: DeveloperApi ::
- *
- * Trait for models and transformers which may be saved as files.
- * This should be inherited by the class which implements model instances.
- *
- * This specializes [[Exportable]] for local models which can be stored on a single machine.
- * This provides helper functionality, but developers can choose to use [[Exportable]] instead,
- * even for local models.
- */
-@DeveloperApi
-trait LocalExportable {
+private[mllib] object Importable {
 
   /**
-   * Save this model to the given path.
-   *
-   * This saves:
-   *  - human-readable (JSON) model metadata to path/metadata/
-   *  - Parquet formatted data to path/data/
+   * Check the schema of loaded model data.
    *
-   * The model may be loaded using [[Importable.load]].
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+   * or containsNull.
    *
-   * @param sc Spark context used to save model data.
-   * @param path Path specifying the directory in which to save this model.
-   *             This directory and any intermediate directory will be created if needed.
+   * @param loadedSchema Schema for model data loaded from file.
+   * @tparam Data Expected data type from which an expected schema can be derived.
    */
-  def save(sc: SparkContext, path: String): Unit
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s" Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data. Expected field $field but found field" +
+        s" with different type: ${loadedFields(field.name)}")
+    }
+  }
 
 }
-*/
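
Correspondingly, a minimal sketch of a loader that uses Importable.checkSchema to validate Parquet-formatted model data before extracting it. Again, the Data layout and the Spark 1.3-era calls (parquetFile, select, Row extraction) are illustrative assumptions, not part of the patch; the compiler materializes the required TypeTag[Data] automatically.

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SQLContext}

// Hypothetical loader for the model data written by the save sketch above.
object MyModelLoader {

  // Must match the case class used when the data was saved (assumed layout).
  private case class Data(weights: Seq[Double], intercept: Double)

  def loadData(sc: SparkContext, path: String): (Seq[Double], Double) = {
    val sqlContext = new SQLContext(sc)
    val dataRDD = sqlContext.parquetFile(path + "/data")
    // Fail fast with a descriptive assertion error if any expected field is missing
    // or has the wrong DataType (metadata and containsNull are not checked).
    Importable.checkSchema[Data](dataRDD.schema)
    dataRDD.select("weights", "intercept").take(1) match {
      case Array(Row(weights: Seq[_], intercept: Double)) =>
        (weights.asInstanceOf[Seq[Double]], intercept)
      case _ =>
        throw new RuntimeException(s"Unable to load model data from: $path/data")
    }
  }
}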