
Commit c82f16c

Nathan Howell authored and rxin committed
[SPARK-18658][SQL] Write text records directly to a FileOutputStream
## What changes were proposed in this pull request?

This replaces uses of `TextOutputFormat` with an `OutputStream`, which will either write directly to the filesystem or indirectly via a compressor (if so configured). This avoids intermediate buffering.

The inverse of this (reading directly from a stream) is necessary for streaming large JSON records (when `wholeFile` is enabled) so I wanted to keep the read and write paths symmetric.

## How was this patch tested?

Existing unit tests.

Author: Nathan Howell <[email protected]>

Closes #16089 from NathanHowell/SPARK-18658.
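A rough sketch of that pattern using plain Hadoop APIs (illustrative only — `writeRecord` and its parameters are made up here; the actual helper this PR adds is `CodecStreams`, shown in the diff below):

```scala
import java.io.OutputStream
import java.nio.charset.StandardCharsets

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.hadoop.util.ReflectionUtils

def writeRecord(
    fs: FileSystem,
    conf: Configuration,
    file: Path,
    record: String,
    compress: Boolean): Unit = {
  // Direct file stream: no Text/RecordWriter buffering in between.
  val raw = fs.create(file, false)
  // Optionally route the bytes through a compressor instead.
  val out: OutputStream =
    if (compress) ReflectionUtils.newInstance(classOf[GzipCodec], conf).createOutputStream(raw)
    else raw
  try out.write((record + "\n").getBytes(StandardCharsets.UTF_8))
  finally out.close()
}
```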
1 parent d3c90b7 commit c82f16c

10 files changed: +252 -144 lines changed


common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 19 additions & 0 deletions
```diff
@@ -147,6 +147,25 @@ public void writeTo(ByteBuffer buffer) {
     buffer.position(pos + numBytes);
   }
 
+  public void writeTo(OutputStream out) throws IOException {
+    if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
+      final byte[] bytes = (byte[]) base;
+
+      // the offset includes an object header... this is only needed for unsafe copies
+      final long arrayOffset = offset - BYTE_ARRAY_OFFSET;
+
+      // verify that the offset and length points somewhere inside the byte array
+      // and that the offset can safely be truncated to a 32-bit integer
+      if ((long) bytes.length < arrayOffset + numBytes) {
+        throw new ArrayIndexOutOfBoundsException();
+      }
+
+      out.write(bytes, (int) arrayOffset, numBytes);
+    } else {
+      out.write(getBytes());
+    }
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point
```
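A minimal usage sketch of the new `writeTo(OutputStream)` (not part of the diff; the `ByteArrayOutputStream` sink is just for illustration):

```scala
import java.io.ByteArrayOutputStream
import org.apache.spark.unsafe.types.UTF8String

val out = new ByteArrayOutputStream()
// Backed by a byte[], so this takes the direct array branch above instead of copying via getBytes().
UTF8String.fromString("spark").writeTo(out)
assert(out.toString("UTF-8") == "spark")
```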

common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 109 additions & 0 deletions
```diff
@@ -17,15 +17,22 @@
 
 package org.apache.spark.unsafe.types;
 
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 
 import com.google.common.collect.ImmutableMap;
+import org.apache.spark.unsafe.Platform;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
 
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
 import static org.apache.spark.unsafe.types.UTF8String.*;
 
 public class UTF8StringSuite {
@@ -499,4 +506,106 @@ public void soundex() {
     assertEquals(fromString("123").soundex(), fromString("123"));
     assertEquals(fromString("世界千世").soundex(), fromString("世界千世"));
   }
+
+  @Test
+  public void writeToOutputStreamUnderflow() throws IOException {
+    // offset underflow is apparently supported?
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) {
+      UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)
+        .writeTo(outputStream);
+      final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length);
+      assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString());
+      outputStream.reset();
+    }
+  }
+
+  @Test
+  public void writeToOutputStreamSlice() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    for (int i = 0; i < test.length; ++i) {
+      for (int j = 0; j < test.length - i; ++j) {
+        UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j)
+          .writeTo(outputStream);
+
+        assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray());
+        outputStream.reset();
+      }
+    }
+  }
+
+  @Test
+  public void writeToOutputStreamOverflow() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    final HashSet<Long> offsets = new HashSet<>();
+    for (int i = 0; i < 16; ++i) {
+      // touch more points around MAX_VALUE
+      offsets.add((long) Integer.MAX_VALUE - i);
+      // subtract off BYTE_ARRAY_OFFSET to avoid wrapping around to a negative value,
+      // which will hit the slower copy path instead of the optimized one
+      offsets.add(Long.MAX_VALUE - BYTE_ARRAY_OFFSET - i);
+    }
+
+    for (long i = 1; i > 0L; i <<= 1) {
+      for (long j = 0; j < 32L; ++j) {
+        offsets.add(i + j);
+      }
+    }
+
+    for (final long offset : offsets) {
+      try {
+        fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length)
+          .writeTo(outputStream);
+
+        throw new IllegalStateException(Long.toString(offset));
+      } catch (ArrayIndexOutOfBoundsException e) {
+        // ignore
+      } finally {
+        outputStream.reset();
+      }
+    }
+  }
+
+  @Test
+  public void writeToOutputStream() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    EMPTY_UTF8.writeTo(outputStream);
+    assertEquals("", outputStream.toString("UTF-8"));
+    outputStream.reset();
+
+    fromString("数据砖很重").writeTo(outputStream);
+    assertEquals(
+      "数据砖很重",
+      outputStream.toString("UTF-8"));
+    outputStream.reset();
+  }
+
+  @Test
+  public void writeToOutputStreamIntArray() throws IOException {
+    // verify that writes work on objects that are not byte arrays
+    final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界");
+    buffer.position(0);
+    buffer.order(ByteOrder.LITTLE_ENDIAN);
+
+    final int length = buffer.limit();
+    assertEquals(12, length);
+
+    final int ints = length / 4;
+    final int[] array = new int[ints];
+
+    for (int i = 0; i < ints; ++i) {
+      array[i] = buffer.getInt();
+    }
+
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    fromAddress(array, Platform.INT_ARRAY_OFFSET, length)
+      .writeTo(outputStream);
+    assertEquals("大千世界", outputStream.toString("UTF-8"));
+  }
 }
```

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 8 additions & 20 deletions
```diff
@@ -21,9 +21,7 @@ import java.io.IOException
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.LabeledPoint
@@ -35,7 +33,6 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -46,30 +43,21 @@ private[libsvm] class LibSVMOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
-  private[this] val buffer = new Text()
-
-  private val recordWriter: RecordWriter[NullWritable, Text] = {
-    new TextOutputFormat[NullWritable, Text]() {
-      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(path)
-      }
-    }.getRecordWriter(context)
-  }
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
   override def write(row: Row): Unit = {
     val label = row.get(0)
     val vector = row.get(1).asInstanceOf[Vector]
-    val sb = new StringBuilder(label.toString)
+    writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
-      sb += ' '
-      sb ++= s"${i + 1}:$v"
+      writer.write(s" ${i + 1}:$v")
    }
-    buffer.set(sb.mkString)
-    recordWriter.write(NullWritable.get(), buffer)
+
+    writer.write('\n')
   }
 
   override def close(): Unit = {
-    recordWriter.close(context)
+    writer.close()
   }
 }
 
@@ -136,7 +124,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       }
 
       override def getFileExtension(context: TaskAttemptContext): String = {
-        ".libsvm" + TextOutputWriter.getCompressionExtension(context)
+        ".libsvm" + CodecStreams.getCompressionExtension(context)
       }
     }
   }
```
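For reference, the rewritten `write` emits one text line per row: the label, then space-separated 1-based `index:value` pairs for the active entries, terminated by `'\n'`. A standalone sketch of that formatting (the `toLibSVMLine` helper is hypothetical, not part of the diff):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

def toLibSVMLine(label: Double, vector: Vector): String = {
  val sb = new StringBuilder(label.toString)
  // Only active (non-zero) entries are written, with 1-based indices.
  vector.foreachActive { case (i, v) => sb ++= s" ${i + 1}:$v" }
  sb.append('\n').toString
}

toLibSVMLine(1.0, Vectors.sparse(4, Array(0, 2), Array(0.5, 1.25)))
// "1.0 1:0.5 3:1.25\n"
```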

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -194,4 +194,8 @@ private[sql] class JacksonGenerator(
       writeFields(row, schema, rootFieldWriters)
     }
   }
+
+  def writeLineEnding(): Unit = {
+    gen.writeRaw('\n')
+  }
 }
```
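The new `writeLineEnding()` just emits a raw `'\n'` after each record so the output stays one JSON document per line. In plain Jackson terms (illustrative; this is not the Spark `JacksonGenerator` wrapper):

```scala
import java.io.ByteArrayOutputStream
import com.fasterxml.jackson.core.JsonFactory

val out = new ByteArrayOutputStream()
val gen = new JsonFactory().createGenerator(out)
gen.writeStartObject()
gen.writeStringField("k", "v")
gen.writeEndObject()
gen.writeRaw('\n') // same call writeLineEnding makes on the underlying generator
gen.flush()
// out now holds {"k":"v"} followed by a newline
```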
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala (new file)

Lines changed: 74 additions & 0 deletions

```diff
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.{OutputStream, OutputStreamWriter}
+import java.nio.charset.{Charset, StandardCharsets}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress._
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+import org.apache.hadoop.util.ReflectionUtils
+
+object CodecStreams {
+  private def getCompressionCodec(
+      context: JobContext,
+      file: Option[Path] = None): Option[CompressionCodec] = {
+    if (FileOutputFormat.getCompressOutput(context)) {
+      val compressorClass = FileOutputFormat.getOutputCompressorClass(
+        context,
+        classOf[GzipCodec])
+
+      Some(ReflectionUtils.newInstance(compressorClass, context.getConfiguration))
+    } else {
+      file.flatMap { path =>
+        val compressionCodecs = new CompressionCodecFactory(context.getConfiguration)
+        Option(compressionCodecs.getCodec(path))
+      }
+    }
+  }
+
+  /**
+   * Create a new file and open it for writing.
+   * If compression is enabled in the [[JobContext]] the stream will write compressed data to disk.
+   * An exception will be thrown if the file already exists.
+   */
+  def createOutputStream(context: JobContext, file: Path): OutputStream = {
+    val fs = file.getFileSystem(context.getConfiguration)
+    val outputStream: OutputStream = fs.create(file, false)
+
+    getCompressionCodec(context, Some(file))
+      .map(codec => codec.createOutputStream(outputStream))
+      .getOrElse(outputStream)
+  }
+
+  def createOutputStreamWriter(
+      context: JobContext,
+      file: Path,
+      charset: Charset = StandardCharsets.UTF_8): OutputStreamWriter = {
+    new OutputStreamWriter(createOutputStream(context, file), charset)
+  }
+
+  /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */
+  def getCompressionExtension(context: JobContext): String = {
+    getCompressionCodec(context)
+      .map(_.getDefaultExtension)
+      .getOrElse("")
+  }
+}
```
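A small sketch of how an `OutputWriter` might call the new helper (the wrapper methods below are illustrative; only `CodecStreams.createOutputStreamWriter` and `getCompressionExtension` come from this diff):

```scala
import java.io.OutputStreamWriter

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.sql.execution.datasources.CodecStreams

// TaskAttemptContext extends JobContext, so it can be passed straight through.
def openWriter(context: TaskAttemptContext, path: String): OutputStreamWriter =
  CodecStreams.createOutputStreamWriter(context, new Path(path))

def fileExtension(context: TaskAttemptContext): String =
  ".txt" + CodecStreams.getCompressionExtension(context) // e.g. ".txt.gz" when gzip output is configured
```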

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala

Lines changed: 8 additions & 11 deletions
```diff
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.{CharArrayWriter, StringReader}
+import java.io.{CharArrayWriter, OutputStream, StringReader}
+import java.nio.charset.StandardCharsets
 
 import com.univocity.parsers.csv._
 
@@ -64,7 +65,10 @@ private[csv] class CsvReader(params: CSVOptions) {
  * @param params Parameters object for configuration
  * @param headers headers for columns
  */
-private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
+private[csv] class LineCsvWriter(
+    params: CSVOptions,
+    headers: Seq[String],
+    output: OutputStream) extends Logging {
   private val writerSettings = new CsvWriterSettings
   private val format = writerSettings.getFormat
 
@@ -80,21 +84,14 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
   writerSettings.setHeaders(headers: _*)
   writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
 
-  private val buffer = new CharArrayWriter()
-  private val writer = new CsvWriter(buffer, writerSettings)
+  private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
 
   def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
     if (includeHeader) {
       writer.writeHeaders()
     }
-    writer.writeRow(row.toArray: _*)
-  }
-
-  def flush(): String = {
-    writer.flush()
-    val lines = buffer.toString.stripLineEnd
-    buffer.reset()
-    lines
+    writer.writeRow(row: _*)
   }
 
   def close(): Unit = {
```
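For context, univocity's `CsvWriter` accepts an `OutputStream` plus charset directly, which is what the new constructor call above relies on. A standalone sketch (the `ByteArrayOutputStream` target is just for illustration):

```scala
import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

import com.univocity.parsers.csv.{CsvWriter, CsvWriterSettings}

val out = new ByteArrayOutputStream()
val writer = new CsvWriter(out, StandardCharsets.UTF_8, new CsvWriterSettings)
writer.writeRow("a", "b", "c") // encoded straight to the stream, no intermediate CharArrayWriter
writer.close()
// out now holds roughly "a,b,c" followed by the configured line separator, as UTF-8 bytes
```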

0 commit comments
