Commit 79bb30c

cleanup OutputWriterFactory and OutputWriter
1 parent 30345c4 commit 79bb30c

12 files changed: +33 −66 lines


mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 6 additions & 3 deletions
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(
 
   private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
-  override def write(row: Row): Unit = {
-    val label = row.get(0)
-    val vector = row.get(1).asInstanceOf[Vector]
+  // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
+  private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
+
+  override def write(row: InternalRow): Unit = {
+    val label = row.getDouble(0)
+    val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
     writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
       writer.write(s" ${i + 1}:$v")
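
Net effect of the hunk above: LibSVMOutputWriter now reads the label as a primitive double and recovers the features vector by deserializing the VectorUDT column, instead of going through a converted Row. A small sketch of that access pattern, assuming the (label: Double, features: Vector) schema that LibSVMFileFormat.verifySchema guarantees and the same access scope as the LibSVM source itself (VectorUDT is Spark-internal); the decodeRow helper is illustrative only:

import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Illustrative helper: dataSchema(1) is assumed to be the VectorUDT column
// that LibSVMFileFormat.verifySchema has already validated.
def decodeRow(dataSchema: StructType, row: InternalRow): (Double, Vector) = {
  val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
  val label = row.getDouble(0)
  // The vector is stored as a struct with udt.sqlType.length fields.
  val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
  (label, vector)
}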

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 2 additions & 2 deletions
@@ -466,7 +466,7 @@ case class DataSource(
     // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
     // not need to have the query as child, to avoid to analyze an optimized query,
     // because InsertIntoHadoopFsRelationCommand will be optimized first.
-    val columns = partitionColumns.map { name =>
+    val partitionAttributes = partitionColumns.map { name =>
      val plan = data.logicalPlan
      plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
        throw new AnalysisException(
@@ -485,7 +485,7 @@
       InsertIntoHadoopFsRelationCommand(
         outputPath = outputPath,
         staticPartitions = Map.empty,
-        partitionColumns = columns,
+        partitionColumns = partitionAttributes,
         bucketSpec = bucketSpec,
         fileFormat = format,
         options = options,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 10 additions & 12 deletions
@@ -64,18 +64,18 @@ object FileFormatWriter extends Logging {
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
       val partitionColumns: Seq[Attribute],
-      val nonPartitionColumns: Seq[Attribute],
+      val dataColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
     extends Serializable {
 
-    assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
+    assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
       s"""
          |All columns: ${allColumns.mkString(", ")}
          |Partition columns: ${partitionColumns.mkString(", ")}
-         |Non-partition columns: ${nonPartitionColumns.mkString(", ")}
+         |Data columns: ${dataColumns.mkString(", ")}
        """.stripMargin)
   }
 
@@ -120,7 +120,7 @@ object FileFormatWriter extends Logging {
       outputWriterFactory = outputWriterFactory,
       allColumns = queryExecution.logical.output,
       partitionColumns = partitionColumns,
-      nonPartitionColumns = dataColumns,
+      dataColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
@@ -246,9 +246,8 @@
 
       currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
-        dataSchema = description.nonPartitionColumns.toStructType,
+        dataSchema = description.dataColumns.toStructType,
         context = taskAttemptContext)
-      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -267,7 +266,7 @@
         }
 
         val internalRow = iter.next()
-        currentWriter.writeInternal(internalRow)
+        currentWriter.write(internalRow)
         recordsInFile += 1
       }
       releaseResources()
@@ -364,9 +363,8 @@
 
       currentWriter = description.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.nonPartitionColumns.toStructType,
+        dataSchema = description.dataColumns.toStructType,
         context = taskAttemptContext)
-      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -383,7 +381,7 @@
 
       // Returns the data columns to be written given an input row
       val getOutputRow = UnsafeProjection.create(
-        description.nonPartitionColumns, description.allColumns)
+        description.dataColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
       val getPartitionStringFunc = UnsafeProjection.create(
@@ -392,7 +390,7 @@
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
         sortingKeySchema,
-        StructType.fromAttributes(description.nonPartitionColumns),
+        StructType.fromAttributes(description.dataColumns),
         SparkEnv.get.blockManager,
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
@@ -448,7 +446,7 @@
           newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
         }
 
-        currentWriter.writeInternal(sortedIterator.getValue)
+        currentWriter.write(sortedIterator.getValue)
         recordsInFile += 1
       }
       releaseResources()
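
Taken together, the FileFormatWriter hunks above leave a single call path on the executor: the factory builds the writer from the data-column schema, and every InternalRow goes straight into write() with no converter set up in between. A condensed sketch of that sequence using the names from the surrounding code; error handling, file rolling on maxRecordsPerFile, and commit-protocol calls are omitted:

// Condensed sketch of the executor-side write loop after this commit.
// `description`, `tmpFilePath`, `taskAttemptContext` and `iter` are the values
// visible in the hunks above; this is not the complete task implementation.
val currentWriter: OutputWriter = description.outputWriterFactory.newInstance(
  path = tmpFilePath,
  dataSchema = description.dataColumns.toStructType,
  context = taskAttemptContext)
// Previously this was followed by currentWriter.initConverter(...); that step is gone.

while (iter.hasNext) {
  currentWriter.write(iter.next())   // InternalRow goes directly to the writer
}
currentWriter.close()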

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand(
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String],
-    @transient query: LogicalPlan,
+    query: LogicalPlan,
     mode: SaveMode,
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex])

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala

Lines changed: 1 addition & 25 deletions
@@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable {
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter
-
-  /**
-   * Returns a new instance of [[OutputWriter]] that will write data to the given path.
-   * This method gets called by each task on executor to write InternalRows to
-   * format-specific files. Compared to the other `newInstance()`, this is a newer API that
-   * passes only the path that the writer must write to. The writer must write to the exact path
-   * and not modify it (do not add subdirectories, extensions, etc.). All other
-   * file-format-specific information needed to create the writer must be passed
-   * through the [[OutputWriterFactory]] implementation.
-   */
-  def newWriter(path: String): OutputWriter = {
-    throw new UnsupportedOperationException("newInstance with just path not supported")
-  }
 }
 
 
@@ -74,22 +61,11 @@ abstract class OutputWriter {
    * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
    * tables, dynamic partition columns are not included in rows to be written.
    */
-  def write(row: Row): Unit
+  def write(row: InternalRow): Unit
 
   /**
   * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
   * the task output is committed.
   */
  def close(): Unit
-
-  private var converter: InternalRow => Row = _
-
-  protected[sql] def initConverter(dataSchema: StructType) = {
-    converter =
-      CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
-  }
-
-  protected[sql] def writeInternal(row: InternalRow): Unit = {
-    write(converter(row))
-  }
 }
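
With this file applied, the OutputWriter contract is reduced to write(InternalRow) plus close(): the Row-based write, the initConverter/writeInternal bridge, and the path-only newWriter factory method are all gone. A minimal sketch of a writer under the trimmed-down contract, modeled on the TextOutputWriter change further below; the class name and the single-string-column format are illustrative, not part of this commit:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}

// Illustrative writer for a single string column. Rows arrive as InternalRow,
// so there is no Row conversion to set up before writing.
class SingleColumnTextWriter(path: String, context: TaskAttemptContext)
  extends OutputWriter {

  private val stream = CodecStreams.createOutputStream(context, new Path(path))

  override def write(row: InternalRow): Unit = {
    if (!row.isNullAt(0)) {
      row.getUTF8String(0).writeTo(stream)   // write the UTF8 bytes directly
    }
    stream.write('\n')
  }

  override def close(): Unit = stream.close()
}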

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala

Lines changed: 1 addition & 3 deletions
@@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter(
       row.get(ordinal, dt).toString
     }
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     csvWriter.writeRow(rowToString(row), printHeader)
     printHeader = false
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala

Lines changed: 1 addition & 3 deletions
@@ -159,9 +159,7 @@ private[json] class JsonOutputWriter(
   // create the Generator without separator inserted between 2 records
   private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     gen.write(row)
     gen.writeLineEnding()
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala

Lines changed: 1 addition & 3 deletions
@@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon
     }.getRecordWriter(context)
   }
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
+  override def write(row: InternalRow): Unit = recordWriter.write(null, row)
 
   override def close(): Unit = recordWriter.close(context)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala

Lines changed: 1 addition & 3 deletions
@@ -132,9 +132,7 @@ class TextOutputWriter(
 
   private val writer = CodecStreams.createOutputStream(context, new Path(path))
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     if (!row.isNullAt(0)) {
       val utf8string = row.getUTF8String(0)
       utf8string.writeTo(writer)

sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala

Lines changed: 1 addition & 4 deletions
@@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter(
     ).asInstanceOf[RecordWriter[NullWritable, Writable]]
   }
 
-  override def write(row: Row): Unit =
-    throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     recordWriter.write(NullWritable.get(), serializer.serialize(row))
   }
 