|
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.param.{ParamMap, Param}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -50,31 +50,30 @@ class SQLTransformer (override val uid: String) extends Transformer { |
 
   private val tableIdentifier: String = "__THIS__"
 
-  /**
-   * The output schema of this transformer.
-   * It is only valid after transform function has been called.
-   */
-  private var outputSchema: StructType = null
-
   override def transform(dataset: DataFrame): DataFrame = {
-    val tableName = uid
-    val realStatement = $(statement).replace(tableIdentifier, tableName)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val tableName = Identifiable.randomUID("sql")
     dataset.registerTempTable(tableName)
-    val originalSchema = dataset.schema
+    val realStatement = $(statement).replace(tableIdentifier, tableName)
     val additiveDF = dataset.sqlContext.sql(realStatement)
-    val additiveSchema = additiveDF.schema
-    outputSchema = StructType(Array.concat(originalSchema.fields, additiveSchema.fields))
-    additiveSchema.fieldNames.foreach {
-      case name =>
-        require(!originalSchema.fieldNames.contains(name), s"Output column $name already exists.")
-    }
     val rdd = dataset.rdd.zip(additiveDF.rdd).map {
       case (r1, r2) => Row.merge(r1, r2)
     }
     dataset.sqlContext.createDataFrame(rdd, outputSchema)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    val sc = SparkContext.getOrCreate()
+    val sqlContext = SQLContext.getOrCreate(sc)
+    val dummyRDD = sc.parallelize(Seq(Row.empty))
+    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
+    dummyDF.registerTempTable(tableIdentifier)
+    val additiveSchema = sqlContext.sql($(statement)).schema
+    additiveSchema.fieldNames.foreach {
+      case name =>
+        require(!schema.fieldNames.contains(name), s"Output column $name already exists.")
+    }
+    val outputSchema = StructType(Array.concat(schema.fields, additiveSchema.fields))
     outputSchema
   }
 
|
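For context, a minimal usage sketch under this patch's semantics, in which the statement's result columns are appended to the input columns. It assumes the standard no-arg constructor and setStatement setter (neither is shown in this hunk) and an existing sqlContext; the toy column names are illustrative:

import org.apache.spark.ml.feature.SQLTransformer

// Toy input with columns "id" and "v"; sqlContext is assumed in scope.
val df = sqlContext.createDataFrame(Seq((0, 1.0), (2, 3.0))).toDF("id", "v")

// __THIS__ is replaced with a freshly generated temp-table name at transform
// time. Selected output names must not collide with input columns, or the
// require(...) check in transformSchema fails.
val sqlTrans = new SQLTransformer()  // assumes the usual no-arg constructor
  .setStatement("SELECT v * 2.0 AS v2 FROM __THIS__")

// Output columns: (id, v, v2), i.e. the input fields plus the statement's.
sqlTrans.transform(df).show()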
|
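Two implementation details are worth calling out: transform now stitches each input row to its SQL-result row via Row.merge over zipped RDDs, and transformSchema derives the output schema by planning the statement against a dummy table instead of caching it in a mutable field. A small sketch of the merge behavior (values are illustrative):

import org.apache.spark.sql.Row

// Row.merge concatenates the fields of its arguments; this is how each
// original row is glued to the corresponding row of the SQL result.
val merged = Row.merge(Row(0, 1.0), Row(2.0))
assert(merged == Row(0, 1.0, 2.0))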