From 3b6768dc7be3f986f8ea52f8498614571a5b5c43 Mon Sep 17 00:00:00 2001
From: Daniel Hegberg
Date: Tue, 22 Oct 2024 20:43:43 -0700
Subject: [PATCH 1/3] Use single file write when an extension is present in the path.

---
 .../src/datasource/file_format/parquet.rs     | 170 +++++++++++++-----
 .../src/datasource/file_format/write/demux.rs |  13 +-
 datafusion/core/src/datasource/listing/url.rs |  63 +++++++
 3 files changed, 199 insertions(+), 47 deletions(-)

diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 8647b5df90be..921172cea520 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -2246,47 +2246,7 @@ mod tests {
 
     #[tokio::test]
     async fn parquet_sink_write() -> Result<()> {
-        let field_a = Field::new("a", DataType::Utf8, false);
-        let field_b = Field::new("b", DataType::Utf8, false);
-        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
-        let object_store_url = ObjectStoreUrl::local_filesystem();
-
-        let file_sink_config = FileSinkConfig {
-            object_store_url: object_store_url.clone(),
-            file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
-            table_paths: vec![ListingTableUrl::parse("file:///")?],
-            output_schema: schema.clone(),
-            table_partition_cols: vec![],
-            insert_op: InsertOp::Overwrite,
-            keep_partition_by_columns: false,
-        };
-        let parquet_sink = Arc::new(ParquetSink::new(
-            file_sink_config,
-            TableParquetOptions {
-                key_value_metadata: std::collections::HashMap::from([
-                    ("my-data".to_string(), Some("stuff".to_string())),
-                    ("my-data-bool-key".to_string(), None),
-                ]),
-                ..Default::default()
-            },
-        ));
-
-        // create data
-        let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
-        let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
-        let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();
-
-        // write stream
-        parquet_sink
-            .write_all(
-                Box::pin(RecordBatchStreamAdapter::new(
-                    schema,
-                    futures::stream::iter(vec![Ok(batch)]),
-                )),
-                &build_ctx(object_store_url.as_ref()),
-            )
-            .await
-            .unwrap();
+        let parquet_sink = create_written_parquet_sink("file:///").await?;
 
         // assert written
         let mut written = parquet_sink.written();
@@ -2338,6 +2298,134 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn parquet_sink_write_with_extension() -> Result<()> {
+        let filename = "test_file.custom_ext";
+        let file_path = format!("file:///path/to/{}", filename);
+        let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?;
+
+        // assert written
+        let mut written = parquet_sink.written();
+        let written = written.drain();
+        assert_eq!(
+            written.len(),
+            1,
+            "expected a single parquet file to be written, instead found {}",
+            written.len()
+        );
+
+        let (
+            path,
+            ..
+        ) = written.take(1).next().unwrap();
+
+        let path_parts = path.parts().collect::<Vec<_>>();
+        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert_eq!(path_parts.last().unwrap().as_ref(), filename);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn parquet_sink_write_with_directory_name() -> Result<()> {
+        let file_path = "file:///path/to";
+        let parquet_sink = create_written_parquet_sink(file_path).await?;
+
+        // assert written
+        let mut written = parquet_sink.written();
+        let written = written.drain();
+        assert_eq!(
+            written.len(),
+            1,
+            "expected a single parquet file to be written, instead found {}",
+            written.len()
+        );
+
+        let (
+            path,
+            ..
+        ) = written.take(1).next().unwrap();
+
+        let path_parts = path.parts().collect::<Vec<_>>();
+        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn parquet_sink_write_with_folder_ending() -> Result<()> {
+        let file_path = "file:///path/to/";
+        let parquet_sink = create_written_parquet_sink(file_path).await?;
+
+        // assert written
+        let mut written = parquet_sink.written();
+        let written = written.drain();
+        assert_eq!(
+            written.len(),
+            1,
+            "expected a single parquet file to be written, instead found {}",
+            written.len()
+        );
+
+        let (
+            path,
+            ..
+        ) = written.take(1).next().unwrap();
+
+        let path_parts = path.parts().collect::<Vec<_>>();
+        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));
+
+        Ok(())
+    }
+
+    async fn create_written_parquet_sink(table_path: &str) -> Result<Arc<ParquetSink>> {
+        let field_a = Field::new("a", DataType::Utf8, false);
+        let field_b = Field::new("b", DataType::Utf8, false);
+        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+        let object_store_url = ObjectStoreUrl::local_filesystem();
+
+        let file_sink_config = FileSinkConfig {
+            object_store_url: object_store_url.clone(),
+            file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+            table_paths: vec![ListingTableUrl::parse(table_path)?],
+            output_schema: schema.clone(),
+            table_partition_cols: vec![],
+            insert_op: InsertOp::Overwrite,
+            keep_partition_by_columns: false,
+        };
+        let parquet_sink = Arc::new(ParquetSink::new(
+            file_sink_config,
+            TableParquetOptions {
+                key_value_metadata: std::collections::HashMap::from([
+                    ("my-data".to_string(), Some("stuff".to_string())),
+                    ("my-data-bool-key".to_string(), None),
+                ]),
+                ..Default::default()
+            },
+        ));
+
+        // create data
+        let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"]));
+        let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"]));
+        let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap();
+
+        // write stream
+        parquet_sink
+            .write_all(
+                Box::pin(RecordBatchStreamAdapter::new(
+                    schema,
+                    futures::stream::iter(vec![Ok(batch)]),
+                )),
+                &build_ctx(object_store_url.as_ref()),
+            )
+            .await
+            .unwrap();
+
+        Ok(parquet_sink)
+    }
+
     #[tokio::test]
     async fn parquet_sink_write_partitions() -> Result<()> {
         let field_a = Field::new("a", DataType::Utf8, false);
diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs
index 427b28db4030..73c895c1d143 100644
--- a/datafusion/core/src/datasource/file_format/write/demux.rs
+++ b/datafusion/core/src/datasource/file_format/write/demux.rs
@@ -59,8 +59,9 @@ type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
 /// which should be contained within the same output file. The outer channel
 /// is used to send a dynamic number of inner channels, representing a dynamic
 /// number of total output files. The caller is also responsible to monitor
-/// the demux task for errors and abort accordingly. The single_file_output parameter
-/// overrides all other settings to force only a single file to be written.
+/// the demux task for errors and abort accordingly. A path with an extension will
+/// force only a single file to be written with the extension from the path. Otherwise,
+/// the default extension will be used and the output will be split into multiple files.
 /// partition_by parameter will additionally split the input based on the unique
 /// values of a specific column ```
 /// ┌───────────┐               ┌────────────┐    ┌─────────────┐
@@ -79,12 +80,12 @@ pub(crate) fn start_demuxer_task(
     context: &Arc<TaskContext>,
     partition_by: Option<Vec<(String, DataType)>>,
     base_output_path: ListingTableUrl,
-    file_extension: String,
+    default_extension: String,
     keep_partition_by_columns: bool,
 ) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
     let (tx, rx) = mpsc::unbounded_channel();
     let context = context.clone();
-    let single_file_output = !base_output_path.is_collection();
+    let single_file_output = !base_output_path.is_collection() && base_output_path.file_extension().is_some();
     let task = match partition_by {
         Some(parts) => {
             // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
@@ -96,7 +97,7 @@ pub(crate) fn start_demuxer_task(
                 context,
                 parts,
                 base_output_path,
-                file_extension,
+                default_extension,
                 keep_partition_by_columns,
             )
             .await
@@ -108,7 +109,7 @@ pub(crate) fn start_demuxer_task(
                 input,
                 context,
                 base_output_path,
-                file_extension,
+                default_extension,
                 single_file_output,
             )
             .await
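
To make the new demux rule concrete for readers outside the codebase, here is a minimal standalone sketch of the decision introduced above. It approximates `is_collection` as a trailing-slash check and `file_extension` as the dotted-suffix check from this series; it is an illustration of the rule, not DataFusion's actual code.

```rust
/// Sketch of the patched rule: demux to a single output file only when the
/// base output path is not a collection (no trailing '/') and its last
/// segment carries an extension; otherwise fall back to multi-file output
/// using the default extension.
fn single_file_output(base_output_path: &str) -> bool {
    let is_collection = base_output_path.ends_with('/');
    let has_extension = base_output_path
        .rsplit('/')
        .next()
        .map(|seg| seg.contains('.') && !seg.ends_with('.'))
        .unwrap_or(false);
    !is_collection && has_extension
}

fn main() {
    assert!(single_file_output("/path/to/test_file.custom_ext")); // single file, name kept
    assert!(!single_file_output("/path/to")); // no extension: auto-named files
    assert!(!single_file_output("/path/to/")); // collection: auto-named files
    println!("demux rule sketch holds");
}
```
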
diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs
index 1701707fdb72..9412f9733dc5 100644
--- a/datafusion/core/src/datasource/listing/url.rs
+++ b/datafusion/core/src/datasource/listing/url.rs
@@ -190,6 +190,19 @@ impl ListingTableUrl {
         self.url.path().ends_with(DELIMITER)
     }
 
+    /// Returns the file extension of the last path segment if it exists
+    pub fn file_extension(&self) -> Option<&str> {
+        if let Some(segments) = self.url.path_segments() {
+            if let Some(last_segment) = segments.last() {
+                if last_segment.contains(".") && !last_segment.ends_with(".") {
+                    return last_segment.split('.').last();
+                }
+            }
+        }
+
+        return None;
+    }
+
     /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning
     /// an iterator of the remaining path segments
     pub(crate) fn strip_prefix<'a, 'b: 'a>(
@@ -493,4 +506,54 @@ mod tests {
             "path not ends with / - fragment ends with / - not collection",
         );
     }
+
+    #[test]
+    fn test_file_extension() {
+        fn test(input: &str, expected: Option<&str>, message: &str) {
+            let url = ListingTableUrl::parse(input).unwrap();
+            assert_eq!(url.file_extension(), expected, "{message}");
+        }
+
+        test("https://a.b.c/path/", None, "path ends with / - not a file");
+        test(
+            "https://a.b.c/path/?a=b",
+            None,
+            "path ends with / - with query args - not a file",
+        );
+        test(
+            "https://a.b.c/path?a=b/",
+            None,
+            "path not ends with / - query ends with / but no file extension",
+        );
+        test(
+            "https://a.b.c/path/#a=b",
+            None,
+            "path ends with / - with fragment - not a file",
+        );
+        test(
+            "https://a.b.c/path#a=b/",
+            None,
+            "path not ends with / - fragment ends with / but no file extension",
+        );
+        test(
+            "file:///some/path/",
+            None,
+            "file path ends with / - not a file",
+        );
+        test(
+            "file:///some/path/file",
+            None,
+            "file path does not end with . - no extension",
+        );
+        test(
+            "file:///some/path/file.",
+            None,
+            "file path ends with . - no value after .",
+        );
+        test(
+            "file:///some/path/file.ext",
+            Some("ext"),
+            "file path ends with .ext - extension is ext",
+        );
+    }
 }
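
Between the patches, a usage-level sketch of what this change means for callers. The `DataFrame::write_parquet` call shape and `DataFrameWriteOptions::new()` defaults are assumptions about the surrounding DataFusion API at this revision, not part of the diff itself:

```rust
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.read_csv("input.csv", CsvReadOptions::new()).await?;

    // Target path carries an extension: with this series, exactly one file
    // named "report.custom_ext" is written, rather than an auto-named file
    // being placed underneath the path.
    df.clone()
        .write_parquet("out/report.custom_ext", DataFrameWriteOptions::new(), None)
        .await?;

    // Target path looks like a directory: behavior is unchanged, and output
    // is demuxed into auto-named files with the default ".parquet" extension.
    df.write_parquet("out/", DataFrameWriteOptions::new(), None)
        .await?;

    Ok(())
}
```
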
From e5dcb5b070f46d1337bd62badb8525549aa5b3c5 Mon Sep 17 00:00:00 2001
From: Daniel Hegberg
Date: Wed, 23 Oct 2024 08:01:43 -0700
Subject: [PATCH 2/3] Adjust formatting.

---
 .../src/datasource/file_format/parquet.rs     | 36 +++++++++++--------
 .../src/datasource/file_format/write/demux.rs |  3 +-
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 921172cea520..91fdc9ad19db 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -2314,13 +2314,15 @@ mod tests {
             written.len()
         );
 
-        let (
-            path,
-            ..
-        ) = written.take(1).next().unwrap();
+        let (path, ..) = written.take(1).next().unwrap();
 
         let path_parts = path.parts().collect::<Vec<_>>();
-        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert_eq!(
+            path_parts.len(),
+            3,
+            "Expected 3 path parts, instead found {}",
+            path_parts.len()
+        );
         assert_eq!(path_parts.last().unwrap().as_ref(), filename);
 
         Ok(())
@@ -2341,13 +2343,15 @@ mod tests {
             written.len()
         );
 
-        let (
-            path,
-            ..
-        ) = written.take(1).next().unwrap();
+        let (path, ..) = written.take(1).next().unwrap();
 
         let path_parts = path.parts().collect::<Vec<_>>();
-        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert_eq!(
+            path_parts.len(),
+            3,
+            "Expected 3 path parts, instead found {}",
+            path_parts.len()
+        );
         assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));
 
         Ok(())
@@ -2368,13 +2372,15 @@ mod tests {
             written.len()
        );
 
-        let (
-            path,
-            ..
-        ) = written.take(1).next().unwrap();
+        let (path, ..) = written.take(1).next().unwrap();
 
         let path_parts = path.parts().collect::<Vec<_>>();
-        assert_eq!(path_parts.len(), 3, "Expected 3 path parts, instead found {}", path_parts.len());
+        assert_eq!(
+            path_parts.len(),
+            3,
+            "Expected 3 path parts, instead found {}",
+            path_parts.len()
+        );
         assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet"));
 
         Ok(())
diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs
index 73c895c1d143..1ed1d8df7b0a 100644
--- a/datafusion/core/src/datasource/file_format/write/demux.rs
+++ b/datafusion/core/src/datasource/file_format/write/demux.rs
@@ -85,7 +85,8 @@ pub(crate) fn start_demuxer_task(
 ) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
     let (tx, rx) = mpsc::unbounded_channel();
     let context = context.clone();
-    let single_file_output = !base_output_path.is_collection() && base_output_path.file_extension().is_some();
+    let single_file_output =
+        !base_output_path.is_collection() && base_output_path.file_extension().is_some();
     let task = match partition_by {
         Some(parts) => {
             // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot

From ca7df7266d865f17e01b69f417f0670c76bfdb0c Mon Sep 17 00:00:00 2001
From: Daniel Hegberg
Date: Wed, 23 Oct 2024 08:20:49 -0700
Subject: [PATCH 3/3] Remove unneeded return statement.

---
 datafusion/core/src/datasource/listing/url.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs
index 9412f9733dc5..e627cacfbfc7 100644
--- a/datafusion/core/src/datasource/listing/url.rs
+++ b/datafusion/core/src/datasource/listing/url.rs
@@ -200,7 +200,7 @@ impl ListingTableUrl {
             }
         }
 
-        return None;
+        None
     }
 
     /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning
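
One closing observation on the extension rule's semantics: because `file_extension` uses `split('.').last()`, a multi-dot segment such as `archive.tar.gz` reports `gz`, a case the new tests do not cover. Below is a small self-contained sketch that mirrors the patched logic on a bare path segment; the helper name is hypothetical and nothing beyond the standard library is assumed.

```rust
// Mirrors ListingTableUrl::file_extension on a single path segment: require a
// dot that is not the final character, then keep the last dotted component.
fn extension_of(last_segment: &str) -> Option<&str> {
    if last_segment.contains('.') && !last_segment.ends_with('.') {
        last_segment.split('.').last()
    } else {
        None
    }
}

fn main() {
    assert_eq!(extension_of("file.ext"), Some("ext"));
    assert_eq!(extension_of("archive.tar.gz"), Some("gz")); // last component wins
    assert_eq!(extension_of("file."), None); // trailing dot: no extension
    assert_eq!(extension_of("file"), None); // no dot: no extension
    println!("extension rule sketch holds");
}
```
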