Skip to content

Commit 0d4bb15

Browse files
committed
a better transformSchema() implementation
1 parent 51eb9e7 commit 0d4bb15

File tree

1 file changed: +16 −17 lines

mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
package org.apache.spark.ml.feature
1919

20+
import org.apache.spark.SparkContext
2021
import org.apache.spark.annotation.Experimental
2122
import org.apache.spark.ml.param.{ParamMap, Param}
2223
import org.apache.spark.ml.Transformer
2324
import org.apache.spark.ml.util.Identifiable
24-
import org.apache.spark.sql.DataFrame
25-
import org.apache.spark.sql.Row
25+
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
2626
import org.apache.spark.sql.types.StructType
2727

2828
/**
@@ -50,31 +50,30 @@ class SQLTransformer (override val uid: String) extends Transformer {
5050

5151
private val tableIdentifier: String = "__THIS__"
5252

53-
/**
54-
* The output schema of this transformer.
55-
* It is only valid after transform function has been called.
56-
*/
57-
private var outputSchema: StructType = null
58-
5953
override def transform(dataset: DataFrame): DataFrame = {
60-
val tableName = uid
61-
val realStatement = $(statement).replace(tableIdentifier, tableName)
54+
val outputSchema = transformSchema(dataset.schema, logging = true)
55+
val tableName = Identifiable.randomUID("sql")
6256
dataset.registerTempTable(tableName)
63-
val originalSchema = dataset.schema
57+
val realStatement = $(statement).replace(tableIdentifier, tableName)
6458
val additiveDF = dataset.sqlContext.sql(realStatement)
65-
val additiveSchema = additiveDF.schema
66-
outputSchema = StructType(Array.concat(originalSchema.fields, additiveSchema.fields))
67-
additiveSchema.fieldNames.foreach {
68-
case name =>
69-
require(!originalSchema.fieldNames.contains(name), s"Output column $name already exists.")
70-
}
7159
val rdd = dataset.rdd.zip(additiveDF.rdd).map {
7260
case (r1, r2) => Row.merge(r1, r2)
7361
}
7462
dataset.sqlContext.createDataFrame(rdd, outputSchema)
7563
}
7664

7765
override def transformSchema(schema: StructType): StructType = {
66+
val sc = SparkContext.getOrCreate()
67+
val sqlContext = SQLContext.getOrCreate(sc)
68+
val dummyRDD = sc.parallelize(Seq(Row.empty))
69+
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
70+
dummyDF.registerTempTable(tableIdentifier)
71+
val additiveSchema = sqlContext.sql($(statement)).schema
72+
additiveSchema.fieldNames.foreach {
73+
case name =>
74+
require(!schema.fieldNames.contains(name), s"Output column $name already exists.")
75+
}
76+
val outputSchema = StructType(Array.concat(schema.fields, additiveSchema.fields))
7877
outputSchema
7978
}
8079

0 commit comments

Comments (0)