diff --git a/datafusion/functions/src/regex/regexpsubstr.rs b/datafusion/functions/src/regex/regexpsubstr.rs index 42bcd412378f..5c6d9900ce19 100644 --- a/datafusion/functions/src/regex/regexpsubstr.rs +++ b/datafusion/functions/src/regex/regexpsubstr.rs @@ -227,6 +227,14 @@ fn regexp_substr_inner( } Some(regex) => regex, }; + + // Check for 'e' flag and set group_num to 1 if not provided + let group_num = if flags.is_some_and(|f| f.contains('e')) { + group_num.or(Some(1)) + } else { + group_num + }; + let regex = compile_regex(regex, flags)?; let mut builder = GenericStringBuilder::::new(); @@ -247,7 +255,6 @@ fn regexp_substr_inner( let matches = get_matches(cleaned_value.as_str(), ®ex, occurrence, group_num); - if matches.is_empty() { builder.append_null(); } else { @@ -307,8 +314,12 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result )); } // Case-sensitive enabled by default - let flags = flags.replace("c", ""); - format!("(?{}){}", flags, regex) + let flags = flags.replace("c", "").replace("e", ""); + if flags.is_empty() { + regex.to_string() + } else { + format!("(?{}){}", flags, regex) + } } }; @@ -469,66 +480,75 @@ mod tests { fn test_regexp_substr_with_params() { let values = [ "", - "aabca aabca", - "abc abc", - "Abcab abc", - "abCab cabc", - "ab", + "aabc aabca vff ddf", + "abc abca abcD vff", + "Abcab abcD caddd", + "abCab cabcd dasaaabc VfFddd", + "ab dasacabd caBcv dasaaabcdv", + ]; + let regex = ["abc", "(abc\\S)|(bca)", "(abc)|(bca)", "(abc)|(vff)|(d)"]; + let flags = ["i", "ie", "e", "i"]; + let group_num = [0, 1, 0, 2]; + let expected = [ + ["", "abc", "abc", "Abc", "abC", "aBc"], + ["", "abca", "abca", "Abca", "abCa", "aBcv"], + ["", "abc", "abc", "bca", "abc", "abc"], + ["", "vff", "vff", "", "VfF", ""], ]; - let regex = "abc"; - let position = 1; - let occurrence = 1; - let flags = "i"; - let group_num = 0; - let expected = ["", "abc", "abc", "Abc", "abC", ""]; // Scalar - values.iter().enumerate().for_each(|(pos, &value)| { - let expected = expected.get(pos).cloned().unwrap(); - // Utf8, LargeUtf8 - for (data_type, scalar) in &[ - ( - DataType::Utf8, - ScalarValue::Utf8 as fn(Option) -> ScalarValue, - ), - ( - DataType::LargeUtf8, - ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, - ), - ] { - let result = - RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(scalar(Some(value.to_string()))), - ColumnarValue::Scalar(scalar(Some(regex.to_string()))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(position))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(occurrence))), - ColumnarValue::Scalar(scalar(Some(flags.to_string()))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(group_num))), - ], - number_rows: 1, - return_type: data_type, - }); - match result { - Ok(ColumnarValue::Scalar( - ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res), - )) => { - if res.is_some() { - assert_eq!( - res.as_ref().unwrap(), - &expected.to_string(), - "regexp_substr scalar test failed" - ); - } else { - assert_eq!( - "", expected, - "regexp_substr scalar utf8 test failed" - ) + regex.iter().enumerate().for_each(|(spos, ®ex)| { + values.iter().enumerate().for_each(|(pos, &value)| { + let expected = expected.get(spos).unwrap().get(pos).cloned().unwrap(); + // Utf8, LargeUtf8 + for (data_type, scalar) in &[ + ( + DataType::Utf8, + ScalarValue::Utf8 as fn(Option) -> ScalarValue, + ), + ( + DataType::LargeUtf8, + ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, + ), + ] { + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(scalar(Some(value.to_string()))), + ColumnarValue::Scalar(scalar(Some(regex.to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(scalar(Some( + flags[spos].to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Int64(Some( + group_num[spos], + ))), + ], + number_rows: 1, + return_type: data_type, + }); + match result { + Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res), + )) => { + if res.is_some() { + assert_eq!( + res.as_ref().unwrap(), + &expected.to_string(), + "regexp_substr scalar test failed" + ); + } else { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ) + } } + _ => panic!("Unexpected result"), } - _ => panic!("Unexpected result"), } - } + }) }); }