|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 |
|
18 | | -from abc import ABCMeta, abstractmethod, abstractproperty |
| 18 | +from pyspark.ml.param import * |
| 19 | +from pyspark.ml.pipeline import * |
19 | 20 |
|
20 | | -from pyspark import SparkContext |
21 | | -from pyspark.sql import SchemaRDD, inherit_doc # TODO: move inherit_doc to Spark Core |
22 | | -from pyspark.ml.param import Param, Params |
23 | | -from pyspark.ml.util import Identifiable |
24 | | - |
25 | | -__all__ = ["Pipeline", "Transformer", "Estimator", "param", "feature", "classification"] |
26 | | - |
27 | | - |
28 | | -def _jvm(): |
29 | | - """ |
30 | | - Returns the JVM view associated with SparkContext. Must be called |
31 | | - after SparkContext is initialized. |
32 | | - """ |
33 | | - jvm = SparkContext._jvm |
34 | | - if jvm: |
35 | | - return jvm |
36 | | - else: |
37 | | - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") |
38 | | - |
39 | | - |
40 | | -@inherit_doc |
41 | | -class PipelineStage(Params): |
42 | | - """ |
43 | | - A stage in a pipeline, either an :py:class:`Estimator` or a |
44 | | - :py:class:`Transformer`. |
45 | | - """ |
46 | | - |
47 | | - __metaclass__ = ABCMeta |
48 | | - |
49 | | - def __init__(self): |
50 | | - super(PipelineStage, self).__init__() |
51 | | - |
52 | | - |
53 | | -@inherit_doc |
54 | | -class Estimator(PipelineStage): |
55 | | - """ |
56 | | - Abstract class for estimators that fit models to data. |
57 | | - """ |
58 | | - |
59 | | - __metaclass__ = ABCMeta |
60 | | - |
61 | | - def __init__(self): |
62 | | - super(Estimator, self).__init__() |
63 | | - |
64 | | - @abstractmethod |
65 | | - def fit(self, dataset, params={}): |
66 | | - """ |
67 | | - Fits a model to the input dataset with optional parameters. |
68 | | -
|
69 | | - :param dataset: input dataset, which is an instance of |
70 | | - :py:class:`pyspark.sql.SchemaRDD` |
71 | | - :param params: an optional param map that overwrites embedded |
72 | | - params |
73 | | - :returns: fitted model |
74 | | - """ |
75 | | - raise NotImplementedError() |
76 | | - |
77 | | - |
78 | | -@inherit_doc |
79 | | -class Transformer(PipelineStage): |
80 | | - """ |
81 | | - Abstract class for transformers that transform one dataset into |
82 | | - another. |
83 | | - """ |
84 | | - |
85 | | - __metaclass__ = ABCMeta |
86 | | - |
87 | | - def __init__(self): |
88 | | - super(Transformer, self).__init__() |
89 | | - |
90 | | - @abstractmethod |
91 | | - def transform(self, dataset, params={}): |
92 | | - """ |
93 | | - Transforms the input dataset with optional parameters. |
94 | | -
|
95 | | - :param dataset: input dataset, which is an instance of |
96 | | - :py:class:`pyspark.sql.SchemaRDD` |
97 | | - :param params: an optional param map that overwrites embedded |
98 | | - params |
99 | | - :returns: transformed dataset |
100 | | - """ |
101 | | - raise NotImplementedError() |
102 | | - |
103 | | - |
104 | | -@inherit_doc |
105 | | -class Model(Transformer): |
106 | | - """ |
107 | | - Abstract class for models fitted by :py:class:`Estimator`s. |
108 | | - """ |
109 | | - |
110 | | - ___metaclass__ = ABCMeta |
111 | | - |
112 | | - def __init__(self): |
113 | | - super(Model, self).__init__() |
114 | | - |
115 | | - |
116 | | -@inherit_doc |
117 | | -class Pipeline(Estimator): |
118 | | - """ |
119 | | - A simple pipeline, which acts as an estimator. A Pipeline consists |
120 | | - of a sequence of stages, each of which is either an |
121 | | - :py:class:`Estimator` or a :py:class:`Transformer`. When |
122 | | - :py:meth:`Pipeline.fit` is called, the stages are executed in |
123 | | - order. If a stage is an :py:class:`Estimator`, its |
124 | | - :py:meth:`Estimator.fit` method will be called on the input |
125 | | - dataset to fit a model. Then the model, which is a transformer, |
126 | | - will be used to transform the dataset as the input to the next |
127 | | - stage. If a stage is a :py:class:`Transformer`, its |
128 | | - :py:meth:`Transformer.transform` method will be called to produce |
129 | | - the dataset for the next stage. The fitted model from a |
130 | | - :py:class:`Pipeline` is an :py:class:`PipelineModel`, which |
131 | | - consists of fitted models and transformers, corresponding to the |
132 | | - pipeline stages. If there are no stages, the pipeline acts as an |
133 | | - identity transformer. |
134 | | - """ |
135 | | - |
136 | | - def __init__(self): |
137 | | - super(Pipeline, self).__init__() |
138 | | - #: Param for pipeline stages. |
139 | | - self.stages = Param(self, "stages", "pipeline stages") |
140 | | - |
141 | | - def setStages(self, value): |
142 | | - """ |
143 | | - Set pipeline stages. |
144 | | - :param value: a list of transformers or estimators |
145 | | - :return: the pipeline instance |
146 | | - """ |
147 | | - self.paramMap[self.stages] = value |
148 | | - return self |
149 | | - |
150 | | - def getStages(self): |
151 | | - """ |
152 | | - Get pipeline stages. |
153 | | - """ |
154 | | - if self.stages in self.paramMap: |
155 | | - return self.paramMap[self.stages] |
156 | | - |
157 | | - def fit(self, dataset, params={}): |
158 | | - paramMap = self._merge_params(params) |
159 | | - stages = paramMap[self.stages] |
160 | | - for stage in stages: |
161 | | - if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): |
162 | | - raise ValueError( |
163 | | - "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) |
164 | | - indexOfLastEstimator = -1 |
165 | | - for i, stage in enumerate(stages): |
166 | | - if isinstance(stage, Estimator): |
167 | | - indexOfLastEstimator = i |
168 | | - transformers = [] |
169 | | - for i, stage in enumerate(stages): |
170 | | - if i <= indexOfLastEstimator: |
171 | | - if isinstance(stage, Transformer): |
172 | | - transformers.append(stage) |
173 | | - dataset = stage.transform(dataset, paramMap) |
174 | | - else: # must be an Estimator |
175 | | - model = stage.fit(dataset, paramMap) |
176 | | - transformers.append(model) |
177 | | - if i < indexOfLastEstimator: |
178 | | - dataset = model.transform(dataset, paramMap) |
179 | | - else: |
180 | | - transformers.append(stage) |
181 | | - return PipelineModel(transformers) |
182 | | - |
183 | | - |
184 | | -@inherit_doc |
185 | | -class PipelineModel(Model): |
186 | | - """ |
187 | | - Represents a compiled pipeline with transformers and fitted models. |
188 | | - """ |
189 | | - |
190 | | - def __init__(self, transformers): |
191 | | - super(PipelineModel, self).__init__() |
192 | | - self.transformers = transformers |
193 | | - |
194 | | - def transform(self, dataset, params={}): |
195 | | - paramMap = self._merge_params(params) |
196 | | - for t in self.transformers: |
197 | | - dataset = t.transform(dataset, paramMap) |
198 | | - return dataset |
199 | | - |
200 | | - |
201 | | -@inherit_doc |
202 | | -class JavaWrapper(Params): |
203 | | - """ |
204 | | - Utility class to help create wrapper classes from Java/Scala |
205 | | - implementations of pipeline components. |
206 | | - """ |
207 | | - |
208 | | - __metaclass__ = ABCMeta |
209 | | - |
210 | | - def __init__(self): |
211 | | - super(JavaWrapper, self).__init__() |
212 | | - |
213 | | - @abstractproperty |
214 | | - def _java_class(self): |
215 | | - """ |
216 | | - Fully-qualified class name of the wrapped Java component. |
217 | | - """ |
218 | | - raise NotImplementedError |
219 | | - |
220 | | - def _java_obj(self): |
221 | | - """ |
222 | | - Returns or creates a Java object. |
223 | | - """ |
224 | | - java_obj = _jvm() |
225 | | - for name in self._java_class.split("."): |
226 | | - java_obj = getattr(java_obj, name) |
227 | | - return java_obj() |
228 | | - |
229 | | - def _transfer_params_to_java(self, params, java_obj): |
230 | | - """ |
231 | | - Transforms the embedded params and additional params to the |
232 | | - input Java object. |
233 | | - :param params: additional params (overwriting embedded values) |
234 | | - :param java_obj: Java object to receive the params |
235 | | - """ |
236 | | - paramMap = self._merge_params(params) |
237 | | - for param in self.params: |
238 | | - if param in paramMap: |
239 | | - java_obj.set(param.name, paramMap[param]) |
240 | | - |
241 | | - def _empty_java_param_map(self): |
242 | | - """ |
243 | | - Returns an empty Java ParamMap reference. |
244 | | - """ |
245 | | - return _jvm().org.apache.spark.ml.param.ParamMap() |
246 | | - |
247 | | - def _create_java_param_map(self, params, java_obj): |
248 | | - paramMap = self._empty_java_param_map() |
249 | | - for param, value in params.items(): |
250 | | - if param.parent is self: |
251 | | - paramMap.put(java_obj.getParam(param.name), value) |
252 | | - return paramMap |
253 | | - |
254 | | - |
255 | | -@inherit_doc |
256 | | -class JavaEstimator(Estimator, JavaWrapper): |
257 | | - """ |
258 | | - Base class for :py:class:`Estimator`s that wrap Java/Scala |
259 | | - implementations. |
260 | | - """ |
261 | | - |
262 | | - __metaclass__ = ABCMeta |
263 | | - |
264 | | - def __init__(self): |
265 | | - super(JavaEstimator, self).__init__() |
266 | | - |
267 | | - @abstractmethod |
268 | | - def _create_model(self, java_model): |
269 | | - """ |
270 | | - Creates a model from the input Java model reference. |
271 | | - """ |
272 | | - raise NotImplementedError |
273 | | - |
274 | | - def _fit_java(self, dataset, params={}): |
275 | | - """ |
276 | | - Fits a Java model to the input dataset. |
277 | | - :param dataset: input dataset, which is an instance of |
278 | | - :py:class:`pyspark.sql.SchemaRDD` |
279 | | - :param params: additional params (overwriting embedded values) |
280 | | - :return: fitted Java model |
281 | | - """ |
282 | | - java_obj = self._java_obj() |
283 | | - self._transfer_params_to_java(params, java_obj) |
284 | | - return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map()) |
285 | | - |
286 | | - def fit(self, dataset, params={}): |
287 | | - java_model = self._fit_java(dataset, params) |
288 | | - return self._create_model(java_model) |
289 | | - |
290 | | - |
291 | | -@inherit_doc |
292 | | -class JavaTransformer(Transformer, JavaWrapper): |
293 | | - """ |
294 | | - Base class for :py:class:`Transformer`s that wrap Java/Scala |
295 | | - implementations. |
296 | | - """ |
297 | | - |
298 | | - __metaclass__ = ABCMeta |
299 | | - |
300 | | - def __init__(self): |
301 | | - super(JavaTransformer, self).__init__() |
302 | | - |
303 | | - def transform(self, dataset, params={}): |
304 | | - java_obj = self._java_obj() |
305 | | - self._transfer_params_to_java({}, java_obj) |
306 | | - java_param_map = self._create_java_param_map(params, java_obj) |
307 | | - return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map), |
308 | | - dataset.sql_ctx) |
309 | | - |
310 | | - |
311 | | -@inherit_doc |
312 | | -class JavaModel(JavaTransformer): |
313 | | - """ |
314 | | - Base class for :py:class:`Model`s that wrap Java/Scala |
315 | | - implementations. |
316 | | - """ |
317 | | - |
318 | | - __metaclass__ = ABCMeta |
319 | | - |
320 | | - def __init__(self): |
321 | | - super(JavaTransformer, self).__init__() |
322 | | - |
323 | | - def _java_obj(self): |
324 | | - return self._java_model |
| 21 | +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] |
0 commit comments