From 8f5ed6b2f50a75ab8a17859bacad038c70945933 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Wed, 15 Jan 2025 11:35:08 +0300 Subject: [PATCH 1/4] Extend references with new SHowSchemas type Make resolve_table_ref as pub Fix docs linter Add DDL AlterTable Add DDL AlterTable Add "rlike" as an alias for regexp_like (#2) * Add "rlike" as an alias for regexp_like * Update docs Extend references with new ShowSchemas type (#4) * Extend references with new SHowSchemas type * Make resolve_table_ref as pub * Fix docs linter Fix deps Import alter table Temp Temp Add regexp_substr udf Add regexp_substr udf Add regexp_substr udf Add regexp_substr udf --- datafusion/core/src/catalog_common/mod.rs | 2 + .../core/src/execution/session_state.rs | 10 +- datafusion/functions/src/regex/mod.rs | 27 + datafusion/functions/src/regex/regexplike.rs | 6 + .../functions/src/regex/regexpsubstr.rs | 532 ++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 9 + 6 files changed, 582 insertions(+), 4 deletions(-) create mode 100644 datafusion/functions/src/regex/regexpsubstr.rs diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs index 68c78dda4899..9827f259bd13 100644 --- a/datafusion/core/src/catalog_common/mod.rs +++ b/datafusion/core/src/catalog_common/mod.rs @@ -156,6 +156,8 @@ pub fn resolve_table_references( | Statement::ShowColumns { .. } | Statement::ShowTables { .. } | Statement::ShowCollation { .. } + | Statement::ShowSchemas { .. } + | Statement::ShowDatabases { .. } ); if requires_information_schema { for s in INFORMATION_SCHEMA_TABLES { diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c5874deb6ed5..19d0fae207be 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -285,7 +285,9 @@ impl SessionState { .build() } - pub(crate) fn resolve_table_ref( + /// Resolves a [`TableReference`] to a [`ResolvedTableReference`] + /// using the default catalog and schema. + pub fn resolve_table_ref( &self, table_ref: impl Into, ) -> ResolvedTableReference { @@ -845,9 +847,9 @@ impl SessionState { overwrite: bool, ) -> Result<(), DataFusionError> { let ext = file_format.get_ext().to_lowercase(); - match (self.file_formats.entry(ext.clone()), overwrite){ - (Entry::Vacant(e), _) => {e.insert(file_format);}, - (Entry::Occupied(mut e), true) => {e.insert(file_format);}, + match (self.file_formats.entry(ext.clone()), overwrite) { + (Entry::Vacant(e), _) => { e.insert(file_format); } + (Entry::Occupied(mut e), true) => { e.insert(file_format); } (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), }; Ok(()) diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 13fbc049af58..c3b695a8cb6e 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -23,12 +23,14 @@ pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; +pub mod regexpsubstr; // create UDFs make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); +make_udf_function!(regexpsubstr::RegexpSubstrFunc, regexp_substr); pub mod expr_fn { use datafusion_expr::Expr; @@ -60,6 +62,31 @@ pub mod expr_fn { super::regexp_match().call(args) } + /// Returns the substring that matches a regular expression within a string. + pub fn regexp_substr( + values: Expr, + regex: Expr, + start: Option, + occurrence: Option, + flags: Option, + group_num: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(occurrence) = occurrence { + args.push(occurrence); + }; + if let Some(flags) = flags { + args.push(flags); + }; + if let Some(group_num) = group_num { + args.push(group_num); + }; + super::regexp_substr().call(args) + } + /// Returns true if a has at least one match in a string, false otherwise. pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 1c826b12ef8f..c50f1af36acb 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -35,6 +35,7 @@ use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct RegexpLikeFunc { signature: Signature, + aliases: Vec, } impl Default for RegexpLikeFunc { @@ -84,6 +85,7 @@ impl RegexpLikeFunc { vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), + aliases: vec![String::from("rlike")], } } } @@ -112,6 +114,10 @@ impl ScalarUDFImpl for RegexpLikeFunc { }) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn invoke_batch( &self, args: &[ColumnarValue], diff --git a/datafusion/functions/src/regex/regexpsubstr.rs b/datafusion/functions/src/regex/regexpsubstr.rs new file mode 100644 index 000000000000..dffb8cea1db0 --- /dev/null +++ b/datafusion/functions/src/regex/regexpsubstr.rs @@ -0,0 +1,532 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Regex expressions +use arrow::array::{ + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, ListBuilder, + OffsetSizeTrait, +}; +use arrow::datatypes::{DataType, Int32Type}; +use arrow::error::ArrowError; +use datafusion_common::plan_err; +use datafusion_common::ScalarValue; +use datafusion_common::{ + cast::as_generic_string_array, internal_err, DataFusionError, Result, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use regex::Regex; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug)] +pub struct RegexpSubstrFunc { + signature: Signature, +} + +impl Default for RegexpSubstrFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpSubstrFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`. + // If that fails, it proceeds to `(LargeUtf8, Utf8)`. + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Int32]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int32]), + TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int32, Int32]), + TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32, Utf8]), + TypeSignature::Exact(vec![ + LargeUtf8, LargeUtf8, Int32, Int32, LargeUtf8, + ]), + TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32, Utf8, Int32]), + TypeSignature::Exact(vec![ + LargeUtf8, LargeUtf8, Int32, Int32, LargeUtf8, Int32, + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpSubstrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_substr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let len = args + .args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_match_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_substr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_substr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_REGEX, + "Returns the substring that matches a [regular expression](https://docs.rs/regex/latest/regex/#syntax) within a string.", + "regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]])") + .with_sql_example(r#"```sql + > select regexp_substr('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i'); + +---------------------------------------------------+ + | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` +Additional examples can be found [here](https://docs.snowflake.com/en/sql-reference/functions/regexp_substr#examples) +"#) + .with_standard_argument("str", Some("String")) + .with_argument("regexp", "Regular expression to match against. + Can be a constant, column, or function.") + .with_argument("position", "Number of characters from the beginning of the string where the function starts searching for matches. Default: 1") + .with_argument("occurrence", "Specifies the first occurrence of the pattern from which to start returning matches.. Default: 1") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .with_argument("group_num", "Specifies which group to extract. Groups are specified by using parentheses in the regular expression.") + .build() + }) +} + +fn regexp_match_func(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => regexp_substr::(args), + DataType::LargeUtf8 => regexp_substr::(args), + other => { + internal_err!("Unsupported data type {other:?} for function regexp_substr") + } + } +} +pub fn regexp_substr(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + let get_int_arg = |index: usize, name: &str| -> Result> { + if args_len > index { + let arg = args[index].as_primitive::(); + if arg.is_empty() { + return plan_err!( + "regexp_substr() requires the {:?} argument to be an integer", + name + ); + } + Ok(Some(arg.value(0))) + } else { + Ok(None) + } + }; + + let values = as_generic_string_array::(&args[0])?; + let regex = Some(as_generic_string_array::(&args[1])?.value(0)); + let start = get_int_arg(2, "position")?; + let occurrence = get_int_arg(3, "occurrence")?; + let flags = if args_len > 4 { + let flags = args[4].as_string::(); + if flags.iter().any(|s| s == Some("g")) { + return plan_err!("regexp_substr() does not support the \"global\" option"); + } + Some(flags.value(0)) + } else { + None + }; + + let group_num = get_int_arg(5, "group_num")?; + + let result = + regexp_substr_inner::(values, regex, start, occurrence, flags, group_num)?; + Ok(Arc::new(result) as ArrayRef) +} + +fn regexp_substr_inner( + values: &GenericStringArray, + regex: Option<&str>, + start: Option, + occurrence: Option, + flags: Option<&str>, + group_num: Option, +) -> Result { + let regex = match regex { + None | Some("") => { + return Ok( + Arc::new(GenericStringArray::::new_null(values.len())) as ArrayRef + ) + } + Some(regex) => regex, + }; + let regex = compile_regex(regex, flags)?; + let mut list_builder = ListBuilder::new(GenericStringBuilder::::new()); + + values.iter().try_for_each(|value| { + match value { + Some(value) => { + // Skip characters from the beginning + let cleaned_value = if let Some(start) = start { + if start < 1 { + return Err(DataFusionError::from(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + ))); + } + value.chars().skip(start as usize - 1).collect() + } else { + value.to_string() + }; + + let matches = + get_matches(cleaned_value.as_str(), ®ex, occurrence, group_num); + if !matches.is_empty() { + // Return only first substring that matches the pattern + if let Some(first_match) = matches.first() { + list_builder.values().append_value(first_match); + list_builder.append(true); + } + } else { + list_builder.append(false); + } + } + _ => list_builder.append(false), + } + Ok(()) + })?; + Ok(Arc::new(list_builder.finish())) +} + +fn get_matches( + value: &str, + regex: &Regex, + occurrence: Option, + group_num: Option, +) -> Vec { + let mut matches = Vec::new(); + let occurrence = occurrence.unwrap_or(1) as usize; + + for caps in regex.captures_iter(value) { + match group_num { + Some(group_num) => { + if let Some(m) = caps.get(group_num as usize) { + matches.push(m.as_str().to_string()); + } + } + None => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + matches.push(m.as_str().to_string()); + } + } + } + } + + if matches.len() > occurrence { + matches = matches.split_off(occurrence - 1); + } + matches +} +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_substr() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +#[cfg(test)] +mod tests { + use crate::regex::regexpsubstr::{regexp_substr, RegexpSubstrFunc}; + use arrow::array::{Array, ArrayAccessor, AsArray, Int32Array, StringArray}; + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn test_regexp_substr() { + let values = [ + "Hellooo Woorld", + "How are you doing today floor?", + "the quick brown fox jumps over the lazy dog door", + "PACK MY BOX WITH FIVE DOZEN LIQUOR JUGS", + ]; + let regex = ["\\b\\S*o\\S*\\b", "(..or)"]; + let expected = [ + ["Hellooo", "How", "brown", ""], + ["Woor", "loor", "door", ""], + ]; + + // Scalar + values.iter().enumerate().for_each(|(pos, &value)| { + regex.iter().enumerate().for_each(|(rpos, regex)| { + let expected = expected.get(rpos).unwrap().get(pos).unwrap().to_string(); + + // Utf8 + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + value.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + regex.to_string(), + ))), + ], + number_rows: 0, + return_type: &DataType::Utf8, + }); + + match result { + Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { + let res = v.value(0); + if res.is_empty() { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ); + } else { + let value = res.as_string::().value(0); + assert_eq!( + value, + expected.to_string(), + "regexp_substr scalar utf8 test failed" + ); + } + } + _ => panic!("Unexpected utf8 result"), + } + + // LargeUtf8 + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + value.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + regex.to_string(), + ))), + ], + number_rows: 0, + return_type: &DataType::LargeUtf8, + }); + + match result { + Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { + let res = v.value(0); + if res.is_empty() { + assert_eq!("", expected, "regexp_substr scalar test failed"); + } else { + let value = res.as_string::().value(0); + assert_eq!( + value, + expected.to_string(), + "regexp_substr scalar test failed" + ); + } + } + _ => panic!("Unexpected result"), + } + }); + }) + } + + #[test] + fn test_regexp_substr_with_params() { + let values = [ + "", + "aabca aabca", + "abc abc", + "Abcab abc", + "abCab cabc", + "ab", + ]; + let regex = "abc"; + let position = 1; + let occurrence = 1; + let flags = "i"; + let group_num = 1; + let expected = ["", "abc", "abc", "Abc", "abC", ""]; + + // Scalar + values.iter().enumerate().for_each(|(pos, &value)| { + let expected = expected.get(pos).cloned().unwrap(); + + // Utf8 + let result = RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(regex.to_string()))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(position))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(occurrence))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(flags.to_string()))), + // ColumnarValue::Scalar(ScalarValue::Int32(Some(group_num))), + ], + number_rows: 0, + return_type: &DataType::Utf8, + }); + + match result { + Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { + let res = v.value(0); + if res.is_empty() { + assert_eq!("", expected, "regexp_substr scalar utf8 test failed"); + } else { + let value = res.as_string::().value(0); + assert_eq!( + value, + expected.to_string(), + "regexp_substr scalar utf8 test failed" + ); + } + } + _ => panic!("Unexpected utf8 result"), + } + + // LargeUtf8 + let result = RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + value.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + regex.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(position))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(occurrence))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + flags.to_string(), + ))), + // ColumnarValue::Scalar(ScalarValue::Int32(Some(group_num))), + ], + number_rows: 0, + return_type: &DataType::LargeUtf8, + }); + + match result { + Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { + let res = v.value(0); + if res.is_empty() { + assert_eq!("", expected, "regexp_substr scalar test failed"); + } else { + let value = res.as_string::().value(0); + assert_eq!( + value, + expected.to_string(), + "regexp_substr scalar test failed" + ); + } + } + _ => panic!("Unexpected result"), + } + }); + } + + #[test] + fn test_unsupported_global_flag_regexp_substr() { + let values = StringArray::from(vec!["abc"]); + let patterns = StringArray::from(vec!["^(a)"]); + let position = Int32Array::from(vec![1]); + let occurrence = Int32Array::from(vec![1]); + let flags = StringArray::from(vec!["g"]); + + let re_err = regexp_substr::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(position), + Arc::new(occurrence), + Arc::new(flags), + ]) + .expect_err("unsupported flag should have failed"); + + assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_substr() does not support the \"global\" option"); + } +} diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ac0978683c36..fdb555894427 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1769,6 +1769,7 @@ The following regular expression functions are supported: - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) +- [rlike](#rlike) ### `regexp_count` @@ -1839,6 +1840,10 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +#### Aliases + +- rlike + ### `regexp_match` Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. @@ -1919,6 +1924,10 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +### `rlike` + +_Alias of [regexp_like](#regexp_like)._ + ## Time and Date Functions - [current_date](#current_date) From 3a223d1af501295dcf56b863e5998f1a1a76a6ed Mon Sep 17 00:00:00 2001 From: osipovartem Date: Wed, 22 Jan 2025 13:16:30 +0300 Subject: [PATCH 2/4] Add regexp_substr to docs --- datafusion/functions/src/regex/mod.rs | 1 + .../functions/src/regex/regexpsubstr.rs | 3 +- .../source/user-guide/sql/scalar_functions.md | 43 +++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index c3b695a8cb6e..39ae0aa4663e 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -118,5 +118,6 @@ pub fn functions() -> Vec> { regexp_match(), regexp_like(), regexp_replace(), + regexp_substr(), ] } diff --git a/datafusion/functions/src/regex/regexpsubstr.rs b/datafusion/functions/src/regex/regexpsubstr.rs index dffb8cea1db0..7cbf724804cc 100644 --- a/datafusion/functions/src/regex/regexpsubstr.rs +++ b/datafusion/functions/src/regex/regexpsubstr.rs @@ -524,8 +524,7 @@ mod tests { Arc::new(position), Arc::new(occurrence), Arc::new(flags), - ]) - .expect_err("unsupported flag should have failed"); + ]).expect_err("unsupported flag should have failed"); assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_substr() does not support the \"global\" option"); } diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index fdb555894427..50f33357bbda 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1769,6 +1769,7 @@ The following regular expression functions are supported: - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) +- [regexp_substr](#regexp_substr) - [rlike](#rlike) ### `regexp_count` @@ -1924,6 +1925,48 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +### `regexp_substr` + +Returns the substring that matches a [regular expression](https://docs.rs/regex/latest/regex/#syntax) within a string. + +``` +regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **position**: Number of characters from the beginning of the string where the function starts searching for matches. Default: 1 +- **occurrence**: Specifies the first occurrence of the pattern from which to start returning matches.. Default: 1 +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? +- **group_num**: Specifies which group to extract. Groups are specified by using parentheses in the regular expression. + +#### Example + +```sql + > select regexp_substr('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i'); + +---------------------------------------------------+ + | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` + +Additional examples can be found [here](https://docs.snowflake.com/en/sql-reference/functions/regexp_substr#examples) + ### `rlike` _Alias of [regexp_like](#regexp_like)._ From bcf3d730232322b918eb3e01fe98dcb11ee7a353 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Thu, 23 Jan 2025 14:00:19 +0300 Subject: [PATCH 3/4] Update comments and fix returned datatype --- .../functions/src/regex/regexpsubstr.rs | 345 ++++++++++-------- 1 file changed, 184 insertions(+), 161 deletions(-) diff --git a/datafusion/functions/src/regex/regexpsubstr.rs b/datafusion/functions/src/regex/regexpsubstr.rs index 7cbf724804cc..42bcd412378f 100644 --- a/datafusion/functions/src/regex/regexpsubstr.rs +++ b/datafusion/functions/src/regex/regexpsubstr.rs @@ -17,10 +17,9 @@ //! Regex expressions use arrow::array::{ - Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, ListBuilder, - OffsetSizeTrait, + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, }; -use arrow::datatypes::{DataType, Int32Type}; +use arrow::datatypes::{DataType, Int64Type}; use arrow::error::ArrowError; use datafusion_common::plan_err; use datafusion_common::ScalarValue; @@ -47,7 +46,7 @@ impl Default for RegexpSubstrFunc { impl RegexpSubstrFunc { pub fn new() -> Self { - use DataType::*; + use DataType::{Int64, LargeUtf8, Utf8}; Self { signature: Signature::one_of( vec![ @@ -56,17 +55,17 @@ impl RegexpSubstrFunc { // If that fails, it proceeds to `(LargeUtf8, Utf8)`. TypeSignature::Exact(vec![Utf8, Utf8]), TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Int32]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int32]), - TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int32, Int32]), - TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]), TypeSignature::Exact(vec![ - LargeUtf8, LargeUtf8, Int32, Int32, LargeUtf8, + LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, ]), - TypeSignature::Exact(vec![Utf8, Utf8, Int32, Int32, Utf8, Int32]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]), TypeSignature::Exact(vec![ - LargeUtf8, LargeUtf8, Int32, Int32, LargeUtf8, Int32, + LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64, ]), ], Volatility::Immutable, @@ -109,7 +108,7 @@ impl ScalarUDFImpl for RegexpSubstrFunc { .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; - let result = regexp_match_func(&args); + let result = regexp_subst_func(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); @@ -137,13 +136,13 @@ fn get_regexp_substr_doc() -> &'static Documentation { +---------------------------------------------------------+ | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | +---------------------------------------------------------+ - | [Köln] | + | Köln | +---------------------------------------------------------+ SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i'); +---------------------------------------------------+ | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) | +---------------------------------------------------+ - | [B] | + | B | +---------------------------------------------------+ ``` Additional examples can be found [here](https://docs.snowflake.com/en/sql-reference/functions/regexp_substr#examples) @@ -156,8 +155,10 @@ Additional examples can be found [here](https://docs.snowflake.com/en/sql-refere .with_argument("flags", r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case + - **c**: case-sensitive: letters match upper or lower case. Default flag - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n + - **e**: extract submatches (for Snowflake compatibility) - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - **U**: swap the meaning of x* and x*?"#) .with_argument("group_num", "Specifies which group to extract. Groups are specified by using parentheses in the regular expression.") @@ -165,7 +166,7 @@ Additional examples can be found [here](https://docs.snowflake.com/en/sql-refere }) } -fn regexp_match_func(args: &[ArrayRef]) -> Result { +fn regexp_subst_func(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => regexp_substr::(args), DataType::LargeUtf8 => regexp_substr::(args), @@ -176,9 +177,9 @@ fn regexp_match_func(args: &[ArrayRef]) -> Result { } pub fn regexp_substr(args: &[ArrayRef]) -> Result { let args_len = args.len(); - let get_int_arg = |index: usize, name: &str| -> Result> { + let get_int_arg = |index: usize, name: &str| -> Result> { if args_len > index { - let arg = args[index].as_primitive::(); + let arg = args[index].as_primitive::(); if arg.is_empty() { return plan_err!( "regexp_substr() requires the {:?} argument to be an integer", @@ -209,27 +210,25 @@ pub fn regexp_substr(args: &[ArrayRef]) -> Result let result = regexp_substr_inner::(values, regex, start, occurrence, flags, group_num)?; - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result)) } fn regexp_substr_inner( values: &GenericStringArray, regex: Option<&str>, - start: Option, - occurrence: Option, + start: Option, + occurrence: Option, flags: Option<&str>, - group_num: Option, + group_num: Option, ) -> Result { let regex = match regex { None | Some("") => { - return Ok( - Arc::new(GenericStringArray::::new_null(values.len())) as ArrayRef - ) + return Ok(Arc::new(GenericStringArray::::new_null(values.len()))) } Some(regex) => regex, }; let regex = compile_regex(regex, flags)?; - let mut list_builder = ListBuilder::new(GenericStringBuilder::::new()); + let mut builder = GenericStringBuilder::::new(); values.iter().try_for_each(|value| { match value { @@ -248,28 +247,28 @@ fn regexp_substr_inner( let matches = get_matches(cleaned_value.as_str(), ®ex, occurrence, group_num); - if !matches.is_empty() { + + if matches.is_empty() { + builder.append_null(); + } else { // Return only first substring that matches the pattern if let Some(first_match) = matches.first() { - list_builder.values().append_value(first_match); - list_builder.append(true); + builder.append_value(first_match); } - } else { - list_builder.append(false); } } - _ => list_builder.append(false), + _ => builder.append_null(), } Ok(()) })?; - Ok(Arc::new(list_builder.finish())) + Ok(Arc::new(builder.finish())) } fn get_matches( value: &str, regex: &Regex, - occurrence: Option, - group_num: Option, + occurrence: Option, + group_num: Option, ) -> Vec { let mut matches = Vec::new(); let occurrence = occurrence.unwrap_or(1) as usize; @@ -307,6 +306,8 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result "regexp_substr() does not support global flag".to_string(), )); } + // Case-sensitive enabled by default + let flags = flags.replace("c", ""); format!("(?{}){}", flags, regex) } }; @@ -322,12 +323,11 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result #[cfg(test)] mod tests { use crate::regex::regexpsubstr::{regexp_substr, RegexpSubstrFunc}; - use arrow::array::{Array, ArrayAccessor, AsArray, Int32Array, StringArray}; + use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_expr_common::columnar_value::ColumnarValue; - use itertools::Itertools; use std::sync::Arc; #[test] @@ -349,74 +349,120 @@ mod tests { regex.iter().enumerate().for_each(|(rpos, regex)| { let expected = expected.get(rpos).unwrap().get(pos).unwrap().to_string(); - // Utf8 - let result = - RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - value.to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - regex.to_string(), - ))), - ], - number_rows: 0, - return_type: &DataType::Utf8, - }); - - match result { - Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { - let res = v.value(0); - if res.is_empty() { - assert_eq!( - "", expected, - "regexp_substr scalar utf8 test failed" - ); - } else { - let value = res.as_string::().value(0); - assert_eq!( - value, - expected.to_string(), - "regexp_substr scalar utf8 test failed" - ); + // Utf8, LargeUtf8 + for (data_type, scalar) in &[ + ( + DataType::Utf8, + ScalarValue::Utf8 as fn(Option) -> ScalarValue, + ), + ( + DataType::LargeUtf8, + ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, + ), + ] { + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(scalar(Some(value.to_string()))), + ColumnarValue::Scalar(scalar(Some(regex.to_string()))), + ], + number_rows: 1, + return_type: data_type, + }); + match result { + Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res), + )) => { + if res.is_some() { + assert_eq!( + res.as_ref().unwrap(), + &expected.to_string(), + "regexp_substr scalar test failed" + ); + } else { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ) + } } + _ => panic!("Unexpected result"), } - _ => panic!("Unexpected utf8 result"), } + }); + }); - // LargeUtf8 + // Array (column) + regex.iter().enumerate().for_each(|(rpos, regex)| { + // Utf8, LargeUtf8 + for data_type in &[DataType::Utf8, DataType::LargeUtf8] { + let (array_values, regex) = match data_type { + DataType::Utf8 => ( + Arc::new(StringArray::from( + values.iter().map(|v| v.to_string()).collect::>(), + )) as ArrayRef, + ScalarValue::Utf8(Some(regex.to_string())), + ), + DataType::LargeUtf8 => ( + Arc::new(LargeStringArray::from( + values.iter().map(|v| v.to_string()).collect::>(), + )) as ArrayRef, + ScalarValue::LargeUtf8(Some(regex.to_string())), + ), + _ => unreachable!(), + }; let result = RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - value.to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - regex.to_string(), - ))), + ColumnarValue::Array(Arc::new(array_values)), + ColumnarValue::Scalar(regex), ], - number_rows: 0, - return_type: &DataType::LargeUtf8, + number_rows: 1, + return_type: data_type, }); - match result { - Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { - let res = v.value(0); - if res.is_empty() { - assert_eq!("", expected, "regexp_substr scalar test failed"); - } else { - let value = res.as_string::().value(0); - assert_eq!( - value, - expected.to_string(), - "regexp_substr scalar test failed" - ); - } + Ok(ColumnarValue::Array(array)) => { + let expected = expected + .get(rpos) + .unwrap() + .iter() + .map(|v| { + if v.is_empty() { + return None; + } + Some(v.to_string()) + }) + .collect::>>(); + + assert_eq!(array.data_type(), data_type, "wrong array datatype"); + match data_type { + DataType::Utf8 => { + let array = + array.as_any().downcast_ref::().unwrap(); + let expected = StringArray::from(expected); + assert_eq!( + array, &expected, + "regexp_substr array Utf8 test failed" + ); + } + DataType::LargeUtf8 => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let expected = LargeStringArray::from(expected); + assert_eq!( + array, &expected, + "regexp_substr array LargeUtf8 test failed" + ); + } + _ => unreachable!(), + }; } _ => panic!("Unexpected result"), } - }); - }) + } + }); } #[test] @@ -433,79 +479,55 @@ mod tests { let position = 1; let occurrence = 1; let flags = "i"; - let group_num = 1; + let group_num = 0; let expected = ["", "abc", "abc", "Abc", "abC", ""]; // Scalar values.iter().enumerate().for_each(|(pos, &value)| { let expected = expected.get(pos).cloned().unwrap(); - - // Utf8 - let result = RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(regex.to_string()))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(position))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(occurrence))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(flags.to_string()))), - // ColumnarValue::Scalar(ScalarValue::Int32(Some(group_num))), - ], - number_rows: 0, - return_type: &DataType::Utf8, - }); - - match result { - Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { - let res = v.value(0); - if res.is_empty() { - assert_eq!("", expected, "regexp_substr scalar utf8 test failed"); - } else { - let value = res.as_string::().value(0); - assert_eq!( - value, - expected.to_string(), - "regexp_substr scalar utf8 test failed" - ); - } - } - _ => panic!("Unexpected utf8 result"), - } - - // LargeUtf8 - let result = RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - value.to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - regex.to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(position))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(occurrence))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( - flags.to_string(), - ))), - // ColumnarValue::Scalar(ScalarValue::Int32(Some(group_num))), - ], - number_rows: 0, - return_type: &DataType::LargeUtf8, - }); - - match result { - Ok(ColumnarValue::Scalar(ScalarValue::List(v))) => { - let res = v.value(0); - if res.is_empty() { - assert_eq!("", expected, "regexp_substr scalar test failed"); - } else { - let value = res.as_string::().value(0); - assert_eq!( - value, - expected.to_string(), - "regexp_substr scalar test failed" - ); + // Utf8, LargeUtf8 + for (data_type, scalar) in &[ + ( + DataType::Utf8, + ScalarValue::Utf8 as fn(Option) -> ScalarValue, + ), + ( + DataType::LargeUtf8, + ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, + ), + ] { + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(scalar(Some(value.to_string()))), + ColumnarValue::Scalar(scalar(Some(regex.to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(position))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(occurrence))), + ColumnarValue::Scalar(scalar(Some(flags.to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(group_num))), + ], + number_rows: 1, + return_type: data_type, + }); + match result { + Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res), + )) => { + if res.is_some() { + assert_eq!( + res.as_ref().unwrap(), + &expected.to_string(), + "regexp_substr scalar test failed" + ); + } else { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ) + } } + _ => panic!("Unexpected result"), } - _ => panic!("Unexpected result"), } }); } @@ -514,8 +536,8 @@ mod tests { fn test_unsupported_global_flag_regexp_substr() { let values = StringArray::from(vec!["abc"]); let patterns = StringArray::from(vec!["^(a)"]); - let position = Int32Array::from(vec![1]); - let occurrence = Int32Array::from(vec![1]); + let position = Int64Array::from(vec![1]); + let occurrence = Int64Array::from(vec![1]); let flags = StringArray::from(vec!["g"]); let re_err = regexp_substr::(&[ @@ -524,7 +546,8 @@ mod tests { Arc::new(position), Arc::new(occurrence), Arc::new(flags), - ]).expect_err("unsupported flag should have failed"); + ]) + .expect_err("unsupported flag should have failed"); assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_substr() does not support the \"global\" option"); } From 63db707a8f7bbc6305718db91284b05bcd1b1611 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Thu, 23 Jan 2025 14:22:56 +0300 Subject: [PATCH 4/4] Update docs --- docs/source/user-guide/sql/scalar_functions.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 50f33357bbda..769e730026f5 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1942,8 +1942,10 @@ regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]]) - **occurrence**: Specifies the first occurrence of the pattern from which to start returning matches.. Default: 1 - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case + - **c**: case-sensitive: letters match upper or lower case. Default flag - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n + - **e**: extract submatches (for Snowflake compatibility) - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - **U**: swap the meaning of x* and x*? - **group_num**: Specifies which group to extract. Groups are specified by using parentheses in the regular expression. @@ -1955,13 +1957,13 @@ regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]]) +---------------------------------------------------------+ | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | +---------------------------------------------------------+ - | [Köln] | + | Köln | +---------------------------------------------------------+ SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i'); +---------------------------------------------------+ | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) | +---------------------------------------------------+ - | [B] | + | B | +---------------------------------------------------+ ```