
Commit 9767041

kevincmchen authored and cloud-fan committed
[SPARK-34432][SQL][TESTS] Add JavaSimpleWritableDataSource
### What changes were proposed in this pull request?

This is a followup of #19269. In #19269, there is only a Scala implementation of the simple writable data source in `DataSourceV2Suite`. This PR adds a Java implementation of it.

### Why are the changes needed?

To improve test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test suites.

Closes #31560 from kevincmchen/SPARK-34432.

Lead-authored-by: kevincmchen <[email protected]>
Co-authored-by: Kevin Pis <[email protected]>
Co-authored-by: Kevin Pis <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 23a5996 commit 9767041

File tree: 3 files changed, +375 −8 lines


sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java

Lines changed: 3 additions & 6 deletions

@@ -17,7 +17,7 @@

 package test.org.apache.spark.sql.connector;

-import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;

@@ -28,11 +28,8 @@
 import org.apache.spark.sql.types.StructType;

 abstract class JavaSimpleBatchTable implements Table, SupportsRead {
-  private static final Set<TableCapability> CAPABILITIES = new HashSet<>(Arrays.asList(
-      TableCapability.BATCH_READ,
-      TableCapability.BATCH_WRITE,
-      TableCapability.TRUNCATE));
-
+  private static final Set<TableCapability> CAPABILITIES =
+      new HashSet<>(Collections.singletonList(TableCapability.BATCH_READ));
   @Override
   public StructType schema() {
     return TestingV2Source.schema();
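With the base table now advertising only BATCH_READ, a writable test table has to declare the write-related capabilities itself. The new JavaSimpleWritableDataSource below does exactly that in its MyTable.capabilities() override, shown here only to connect the two changes:

  @Override
  public Set<TableCapability> capabilities() {
    return new HashSet<>(Arrays.asList(
        TableCapability.BATCH_READ,
        TableCapability.BATCH_WRITE,
        TableCapability.TRUNCATE));
  }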
sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleWritableDataSource.java

Lines changed: 371 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql.connector;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;

import org.apache.spark.deploy.SparkHadoopUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.SimpleCounter;
import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.SessionConfigSupport;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.*;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.util.SerializableConfiguration;

/**
 * An HDFS-based transactional writable data source implemented in Java.
 * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
 * Each job moves files from `target/_temporary/uniqueId/` to `target`.
 */
public class JavaSimpleWritableDataSource implements TestingV2Source, SessionConfigSupport {

  @Override
  public String keyPrefix() {
    return "javaSimpleWritableDataSource";
  }

  static class MyScanBuilder extends JavaSimpleScanBuilder {

    private final String path;
    private final Configuration conf;

    MyScanBuilder(String path, Configuration conf) {
      this.path = path;
      this.conf = conf;
    }

    @Override
    public InputPartition[] planInputPartitions() {
      Path dataPath = new Path(this.path);
      try {
        FileSystem fs = dataPath.getFileSystem(conf);
        if (fs.exists(dataPath)) {
          return Arrays.stream(fs.listStatus(dataPath))
              .filter(status -> {
                String name = status.getPath().getName();
                return !name.startsWith("_") && !name.startsWith(".");
              })
              .map(f -> new JavaCSVInputPartitionReader(f.getPath().toUri().toString()))
              .toArray(InputPartition[]::new);
        } else {
          return new InputPartition[0];
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public PartitionReaderFactory createReaderFactory() {
      SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
      return new JavaCSVReaderFactory(serializableConf);
    }
  }

  static class MyWriteBuilder implements WriteBuilder, SupportsTruncate {

    private final String path;
    private final String queryId;
    private boolean needTruncate = false;

    MyWriteBuilder(String path, LogicalWriteInfo info) {
      this.path = path;
      this.queryId = info.queryId();
    }

    @Override
    public WriteBuilder truncate() {
      this.needTruncate = true;
      return this;
    }

    @Override
    public Write build() {
      return new MyWrite(path, queryId, needTruncate);
    }
  }

  static class MyWrite implements Write {

    private final String path;
    private final String queryId;
    private final boolean needTruncate;

    MyWrite(String path, String queryId, boolean needTruncate) {
      this.path = path;
      this.queryId = queryId;
      this.needTruncate = needTruncate;
    }

    @Override
    public BatchWrite toBatch() {
      Path hadoopPath = new Path(path);
      Configuration hadoopConf = SparkHadoopUtil.get().conf();
      try {
        FileSystem fs = hadoopPath.getFileSystem(hadoopConf);
        if (needTruncate) {
          fs.delete(hadoopPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      String pathStr = hadoopPath.toUri().toString();
      return new MyBatchWrite(queryId, pathStr, hadoopConf);
    }
  }

  static class MyBatchWrite implements BatchWrite {

    private final String queryId;
    private final String path;
    private final Configuration conf;

    MyBatchWrite(String queryId, String path, Configuration conf) {
      this.queryId = queryId;
      this.path = path;
      this.conf = conf;
    }

    @Override
    public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
      SimpleCounter.resetCounter();
      return new JavaCSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf));
    }

    @Override
    public void onDataWriterCommit(WriterCommitMessage message) {
      SimpleCounter.increaseCounter();
    }

    @Override
    public void commit(WriterCommitMessage[] messages) {
      Path finalPath = new Path(this.path);
      Path jobPath = new Path(new Path(finalPath, "_temporary"), queryId);
      try {
        FileSystem fs = jobPath.getFileSystem(conf);
        FileStatus[] fileStatuses = fs.listStatus(jobPath);
        try {
          for (FileStatus status : fileStatuses) {
            Path file = status.getPath();
            Path dest = new Path(finalPath, file.getName());
            if (!fs.rename(file, dest)) {
              throw new IOException(String.format("failed to rename(%s, %s)", file, dest));
            }
          }
        } finally {
          fs.delete(jobPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public void abort(WriterCommitMessage[] messages) {
      try {
        Path jobPath = new Path(new Path(this.path, "_temporary"), queryId);
        FileSystem fs = jobPath.getFileSystem(conf);
        fs.delete(jobPath, true);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class MyTable extends JavaSimpleBatchTable implements SupportsWrite {

    private final String path;
    private final Configuration conf = SparkHadoopUtil.get().conf();

    MyTable(CaseInsensitiveStringMap options) {
      this.path = options.get("path");
    }

    @Override
    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
      return new MyScanBuilder(new Path(path).toUri().toString(), conf);
    }

    @Override
    public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
      return new MyWriteBuilder(path, info);
    }

    @Override
    public Set<TableCapability> capabilities() {
      return new HashSet<>(Arrays.asList(
          TableCapability.BATCH_READ,
          TableCapability.BATCH_WRITE,
          TableCapability.TRUNCATE));
    }
  }

  @Override
  public Table getTable(CaseInsensitiveStringMap options) {
    return new MyTable(options);
  }

  static class JavaCSVInputPartitionReader implements InputPartition {

    private String path;

    JavaCSVInputPartitionReader(String path) {
      this.path = path;
    }

    public String getPath() {
      return path;
    }

    public void setPath(String path) {
      this.path = path;
    }
  }

  static class JavaCSVReaderFactory implements PartitionReaderFactory {

    private final SerializableConfiguration conf;

    JavaCSVReaderFactory(SerializableConfiguration conf) {
      this.conf = conf;
    }

    @Override
    public PartitionReader<InternalRow> createReader(InputPartition partition) {
      String path = ((JavaCSVInputPartitionReader) partition).getPath();
      Path filePath = new Path(path);
      try {
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new PartitionReader<InternalRow>() {
          private final FSDataInputStream inputStream = fs.open(filePath);
          private final Iterator<String> lines =
              new BufferedReader(new InputStreamReader(inputStream)).lines().iterator();
          private String currentLine = "";

          @Override
          public boolean next() {
            if (lines.hasNext()) {
              currentLine = lines.next();
              return true;
            } else {
              return false;
            }
          }

          @Override
          public InternalRow get() {
            Object[] objects =
                Arrays.stream(currentLine.split(","))
                    .map(String::trim)
                    .map(Integer::parseInt)
                    .toArray();
            return new GenericInternalRow(objects);
          }

          @Override
          public void close() throws IOException {
            inputStream.close();
          }
        };
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class JavaCSVDataWriterFactory implements DataWriterFactory {

    private final String path;
    private final String jobId;
    private final SerializableConfiguration conf;

    JavaCSVDataWriterFactory(String path, String jobId, SerializableConfiguration conf) {
      this.path = path;
      this.jobId = jobId;
      this.conf = conf;
    }

    @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
      try {
        Path jobPath = new Path(new Path(path, "_temporary"), jobId);
        Path filePath = new Path(jobPath, String.format("%s-%d-%d", jobId, partitionId, taskId));
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new JavaCSVDataWriter(fs, filePath);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class JavaCSVDataWriter implements DataWriter<InternalRow> {

    private final FileSystem fs;
    private final Path file;
    private final FSDataOutputStream out;

    JavaCSVDataWriter(FileSystem fs, Path file) throws IOException {
      this.fs = fs;
      this.file = file;
      out = fs.create(file);
    }

    @Override
    public void write(InternalRow record) throws IOException {
      out.writeBytes(String.format("%d,%d\n", record.getInt(0), record.getInt(1)));
    }

    @Override
    public WriterCommitMessage commit() throws IOException {
      out.close();
      return null;
    }

    @Override
    public void abort() throws IOException {
      try {
        out.close();
      } finally {
        fs.delete(file, false);
      }
    }

    @Override
    public void close() {
    }
  }
}
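For reference, a minimal sketch (not part of this commit) of how a test might exercise this source through the DataFrame API. The SparkSession setup and output directory below are illustrative assumptions; the two int columns match what JavaCSVDataWriter.write reads from each row, and overwrite mode goes through the SupportsTruncate path implemented by MyWriteBuilder.

  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Row;
  import org.apache.spark.sql.SaveMode;
  import org.apache.spark.sql.SparkSession;

  public class JavaSimpleWritableDataSourceExample {
    public static void main(String[] args) {
      SparkSession spark = SparkSession.builder()
          .master("local[2]")
          .appName("JavaSimpleWritableDataSourceExample")
          .getOrCreate();
      // Illustrative output directory; the real suite would use a temporary directory.
      String path = "/tmp/java-simple-writable-data-source";
      String format = "test.org.apache.spark.sql.connector.JavaSimpleWritableDataSource";

      // The source's schema is two int columns, so cast the range output accordingly.
      Dataset<Row> df = spark.range(10)
          .selectExpr("CAST(id AS INT) AS i", "CAST(-id AS INT) AS j");

      // Append goes through MyWriteBuilder/MyBatchWrite: tasks write to
      // `_temporary/<queryId>` and commit() renames the files into the target directory.
      df.write().format(format).option("path", path).mode(SaveMode.Append).save();

      // Overwrite exercises SupportsTruncate: truncate() is called on the builder,
      // so the target directory is deleted before the new files are committed.
      df.write().format(format).option("path", path).mode(SaveMode.Overwrite).save();

      // Read back through MyScanBuilder / JavaCSVReaderFactory.
      spark.read().format(format).option("path", path).load().show();

      spark.stop();
    }
  }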
