
Commit 56667bd

Author: Nathan Howell (committed)
[SPARK-18658][SQL] Write text records directly to a FileOutputStream
1 parent c51c772 commit 56667bd

File tree

8 files changed, +120 -143 lines changed

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 8 additions & 20 deletions
@@ -21,9 +21,7 @@ import java.io.IOException
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.LabeledPoint
@@ -35,7 +33,6 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -46,30 +43,21 @@ private[libsvm] class LibSVMOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
-  private[this] val buffer = new Text()
-
-  private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
-      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(path)
-      }
-    }.getRecordWriter(context)
-  }
+  private val writer = CodecStreams.getOutputStreamWriter(context, new Path(path))
 
   override def write(row: Row): Unit = {
     val label = row.get(0)
     val vector = row.get(1).asInstanceOf[Vector]
-    val sb = new StringBuilder(label.toString)
+    writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
-      sb += ' '
-      sb ++= s"${i + 1}:$v"
+      writer.write(s" ${i + 1}:$v")
     }
-    buffer.set(sb.mkString)
-    recordWriter.write(NullWritable.get(), buffer)
+
+    writer.write('\n')
   }
 
   override def close(): Unit = {
-    recordWriter.close(context)
+    writer.close()
   }
 }
 
@@ -136,7 +124,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       }
 
      override def getFileExtension(context: TaskAttemptContext): String = {
-       ".libsvm" + TextOutputWriter.getCompressionExtension(context)
+       ".libsvm" + CodecStreams.getCompressionExtension(context)
      }
    }
  }
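Note: the rewritten LibSVMOutputWriter emits one record per line, i.e. the label followed by space-separated 1-based index:value pairs for the vector's active entries, terminated by '\n'. A minimal standalone sketch of that line format (the object name and sample values are illustrative, not part of the commit):

import org.apache.spark.ml.linalg.Vectors

object LibSVMLineSketch {
  def main(args: Array[String]): Unit = {
    val label = 1.0
    // A sparse vector of size 4 with active entries at indices 0 and 2.
    val vector = Vectors.sparse(4, Array(0, 2), Array(0.5, 2.0))

    val sb = new StringBuilder(label.toString)
    // LIBSVM indices are 1-based, so each active index is shifted by one.
    vector.foreachActive { case (i, v) => sb ++= s" ${i + 1}:$v" }

    println(sb.toString) // prints: 1.0 1:0.5 3:2.0
  }
}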

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala

Lines changed: 4 additions & 0 deletions
@@ -194,4 +194,8 @@ private[sql] class JacksonGenerator(
       writeFields(row, schema, rootFieldWriters)
     }
   }
+
+  private[sql] def writeLineEnding(): Unit = {
+    gen.writeRaw('\n')
+  }
 }
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.{OutputStream, OutputStreamWriter}
+import java.nio.charset.{Charset, StandardCharsets}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress._
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+import org.apache.hadoop.util.ReflectionUtils
+
+object CodecStreams {
+  private def getCompressionCodec(
+      context: JobContext,
+      file: Option[Path] = None): Option[CompressionCodec] = {
+    if (FileOutputFormat.getCompressOutput(context)) {
+      val compressorClass = FileOutputFormat.getOutputCompressorClass(
+        context,
+        classOf[GzipCodec])
+
+      Some(ReflectionUtils.newInstance(compressorClass, context.getConfiguration))
+    } else {
+      file.flatMap { path =>
+        val compressionCodecs = new CompressionCodecFactory(context.getConfiguration)
+        Option(compressionCodecs.getCodec(path))
+      }
+    }
+  }
+
+  /** Create a new file and open it for writing.
+   * If compression is enabled in the [[JobContext]] the stream will write compressed data to disk.
+   * An exception will be thrown if the file already exists.
+   */
+  def getOutputStream(context: JobContext, file: Path): OutputStream = {
+    val fs = file.getFileSystem(context.getConfiguration)
+    val outputStream: OutputStream = fs.create(file, false)
+
+    getCompressionCodec(context, Some(file)).fold(outputStream) { codec =>
+      codec.createOutputStream(outputStream)
+    }
+  }
+
+  def getOutputStreamWriter(
+      context: JobContext,
+      file: Path,
+      charset: Charset = StandardCharsets.UTF_8): OutputStreamWriter = {
+    new OutputStreamWriter(getOutputStream(context, file), charset)
+  }
+
+  /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */
+  def getCompressionExtension(context: JobContext): String = {
+    getCompressionCodec(context).fold("") { code =>
+      code.getDefaultExtension
+    }
+  }
+}
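Note: CodecStreams centralizes creation of the (optionally compressed) output stream and writer used by the text-based writers in this commit. A minimal sketch of exercising it outside Spark's write path, assuming a hand-built Hadoop task context and a local target path (both illustrative); note that fs.create(file, false) throws if the file already exists:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark.sql.execution.datasources.CodecStreams

object CodecStreamsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // In Spark this context is supplied by the write path; here it is built by hand.
    val context = new TaskAttemptContextImpl(conf, new TaskAttemptID())

    // Compression is off by default, so getCompressionExtension(context) returns "".
    val file = new Path("/tmp/codec-streams-sketch.txt" + CodecStreams.getCompressionExtension(context))

    val writer = CodecStreams.getOutputStreamWriter(context, file)
    try {
      writer.write("one record per line")
      writer.write('\n')
    } finally {
      writer.close()
    }
  }
}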

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala

Lines changed: 8 additions & 11 deletions
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.{CharArrayWriter, StringReader}
+import java.io.{CharArrayWriter, OutputStream, StringReader}
+import java.nio.charset.StandardCharsets
 
 import com.univocity.parsers.csv._
 
@@ -64,7 +65,10 @@ private[csv] class CsvReader(params: CSVOptions) {
  * @param params Parameters object for configuration
  * @param headers headers for columns
  */
-private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
+private[csv] class LineCsvWriter(
+    params: CSVOptions,
+    headers: Seq[String],
+    output: OutputStream) extends Logging {
   private val writerSettings = new CsvWriterSettings
   private val format = writerSettings.getFormat
 
@@ -80,21 +84,14 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
   writerSettings.setHeaders(headers: _*)
   writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
 
-  private val buffer = new CharArrayWriter()
-  private val writer = new CsvWriter(buffer, writerSettings)
+  private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
 
   def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
     if (includeHeader) {
       writer.writeHeaders()
     }
-    writer.writeRow(row.toArray: _*)
-  }
 
-  def flush(): String = {
-    writer.flush()
-    val lines = buffer.toString.stripLineEnd
-    buffer.reset()
-    lines
+    writer.writeRow(row.toArray: _*)
   }
 
   def close(): Unit = {
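Note: LineCsvWriter now hands the destination stream directly to univocity's CsvWriter instead of buffering rows in a CharArrayWriter and periodically flushing them through a Hadoop RecordWriter. A small standalone sketch of that direct-to-stream pattern, with a ByteArrayOutputStream standing in for the file stream (object name and sample values are illustrative):

import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

import com.univocity.parsers.csv.{CsvWriter, CsvWriterSettings}

object CsvDirectWriteSketch {
  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    val settings = new CsvWriterSettings()
    settings.setHeaders("label", "feature")

    // Rows are serialized straight to the stream; no intermediate character buffer.
    val writer = new CsvWriter(out, StandardCharsets.UTF_8, settings)
    writer.writeHeaders()
    writer.writeRow("1.0", "0.5")
    writer.close()

    print(out.toString("UTF-8"))
    // label,feature
    // 1.0,0.5
  }
}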

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala

Lines changed: 8 additions & 35 deletions
@@ -20,19 +20,15 @@ package org.apache.spark.sql.execution.datasources.csv
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.RecordWriter
 import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
+import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.types._
 
 object CSVRelation extends Logging {
@@ -179,7 +175,7 @@ private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
   }
 
   override def getFileExtension(context: TaskAttemptContext): String = {
-    ".csv" + TextOutputWriter.getCompressionExtension(context)
+    ".csv" + CodecStreams.getCompressionExtension(context)
   }
 }
 
@@ -189,9 +185,6 @@ private[csv] class CsvOutputWriter(
     context: TaskAttemptContext,
     params: CSVOptions) extends OutputWriter with Logging {
 
-  // create the Generator without separator inserted between 2 records
-  private[this] val text = new Text()
-
   // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
   // When the value is null, this converter should not be called.
   private type ValueConverter = (InternalRow, Int) => String
@@ -200,17 +193,9 @@ private[csv] class CsvOutputWriter(
   private val valueConverters: Array[ValueConverter] =
     dataSchema.map(_.dataType).map(makeConverter).toArray
 
-  private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
-      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(path)
-      }
-    }.getRecordWriter(context)
-  }
-
-  private val FLUSH_BATCH_SIZE = 1024L
-  private var records: Long = 0L
-  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+  private var printHeader: Boolean = params.headerFlag
+  private val writer = CodecStreams.getOutputStream(context, new Path(path))
+  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer)
 
   private def rowToString(row: InternalRow): Seq[String] = {
     var i = 0
@@ -245,24 +230,12 @@ private[csv] class CsvOutputWriter(
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
 
   override protected[sql] def writeInternal(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
-    records += 1
-    if (records % FLUSH_BATCH_SIZE == 0) {
-      flush()
-    }
-  }
-
-  private def flush(): Unit = {
-    val lines = csvWriter.flush()
-    if (lines.nonEmpty) {
-      text.set(lines)
-      recordWriter.write(NullWritable.get(), text)
-    }
+    csvWriter.writeRow(rowToString(row), printHeader)
+    printHeader = false
   }
 
   override def close(): Unit = {
-    flush()
     csvWriter.close()
-    recordWriter.close(context)
+    writer.close()
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala

Lines changed: 7 additions & 24 deletions
@@ -17,15 +17,12 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
-import java.io.CharArrayWriter
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
@@ -35,7 +32,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -90,7 +86,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       }
 
      override def getFileExtension(context: TaskAttemptContext): String = {
-       ".json" + TextOutputWriter.getCompressionExtension(context)
+       ".json" + CodecStreams.getCompressionExtension(context)
      }
    }
  }
@@ -163,33 +159,20 @@ private[json] class JsonOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter with Logging {
 
-  private[this] val writer = new CharArrayWriter()
+  private val writer = CodecStreams.getOutputStreamWriter(context, new Path(path))
+
   // create the Generator without separator inserted between 2 records
   private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
-  private[this] val result = new Text()
-
-  private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
-      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(path)
-      }
-    }.getRecordWriter(context)
-  }
 
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
 
   override protected[sql] def writeInternal(row: InternalRow): Unit = {
     gen.write(row)
-    gen.flush()
-
-    result.set(writer.toString)
-    writer.reset()
-
-    recordWriter.write(NullWritable.get(), result)
+    gen.writeLineEnding()
   }
 
   override def close(): Unit = {
     gen.close()
-    recordWriter.close(context)
+    writer.close()
   }
 }
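Note: JsonOutputWriter now lets JacksonGenerator write each record straight to the file's OutputStreamWriter and separates records with a raw '\n' via the new writeLineEnding() helper, rather than round-tripping through a CharArrayWriter and Text. A minimal sketch of the underlying Jackson calls, assuming only jackson-core on the classpath (object name and field are illustrative):

import java.io.StringWriter

import com.fasterxml.jackson.core.JsonFactory

object JsonLineEndingSketch {
  def main(args: Array[String]): Unit = {
    val out = new StringWriter()
    val gen = new JsonFactory().createGenerator(out)

    // One JSON object per record, followed by a raw newline, as writeLineEnding() does.
    gen.writeStartObject()
    gen.writeNumberField("a", 1)
    gen.writeEndObject()
    gen.writeRaw('\n')
    gen.flush()

    print(out.toString) // {"a":1} followed by a newline
  }
}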
