Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 77 additions & 57 deletions datafusion/functions/src/regex/regexpsubstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ fn regexp_substr_inner<T: OffsetSizeTrait>(
}
Some(regex) => regex,
};

// When the 'e' flag is present and no group_num was supplied, default group_num to 1
let group_num = if flags.is_some_and(|f| f.contains('e')) {
group_num.or(Some(1))
} else {
group_num
};

let regex = compile_regex(regex, flags)?;
let mut builder = GenericStringBuilder::<T>::new();

Expand All @@ -247,7 +255,6 @@ fn regexp_substr_inner<T: OffsetSizeTrait>(

let matches =
get_matches(cleaned_value.as_str(), &regex, occurrence, group_num);

if matches.is_empty() {
builder.append_null();
} else {
Expand Down Expand Up @@ -307,8 +314,12 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError>
));
}
// Matching is case-sensitive by default, so the 'c' flag is redundant and is stripped
let flags = flags.replace("c", "");
format!("(?{}){}", flags, regex)
let flags = flags.replace("c", "").replace("e", "");
if flags.is_empty() {
regex.to_string()
} else {
format!("(?{}){}", flags, regex)
}
}
};

Expand Down Expand Up @@ -469,66 +480,75 @@ mod tests {
fn test_regexp_substr_with_params() {
let values = [
"",
"aabca aabca",
"abc abc",
"Abcab abc",
"abCab cabc",
"ab",
"aabc aabca vff ddf",
"abc abca abcD vff",
"Abcab abcD caddd",
"abCab cabcd dasaaabc VfFddd",
"ab dasacabd caBcv dasaaabcdv",
];
let regex = ["abc", "(abc\\S)|(bca)", "(abc)|(bca)", "(abc)|(vff)|(d)"];
let flags = ["i", "ie", "e", "i"];
let group_num = [0, 1, 0, 2];
let expected = [
["", "abc", "abc", "Abc", "abC", "aBc"],
["", "abca", "abca", "Abca", "abCa", "aBcv"],
["", "abc", "abc", "bca", "abc", "abc"],
["", "vff", "vff", "", "VfF", ""],
];
let regex = "abc";
let position = 1;
let occurrence = 1;
let flags = "i";
let group_num = 0;
let expected = ["", "abc", "abc", "Abc", "abC", ""];

// Scalar
values.iter().enumerate().for_each(|(pos, &value)| {
let expected = expected.get(pos).cloned().unwrap();
// Utf8, LargeUtf8
for (data_type, scalar) in &[
(
DataType::Utf8,
ScalarValue::Utf8 as fn(Option<String>) -> ScalarValue,
),
(
DataType::LargeUtf8,
ScalarValue::LargeUtf8 as fn(Option<String>) -> ScalarValue,
),
] {
let result =
RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(scalar(Some(value.to_string()))),
ColumnarValue::Scalar(scalar(Some(regex.to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(position))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(occurrence))),
ColumnarValue::Scalar(scalar(Some(flags.to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(group_num))),
],
number_rows: 1,
return_type: data_type,
});
match result {
Ok(ColumnarValue::Scalar(
ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res),
)) => {
if res.is_some() {
assert_eq!(
res.as_ref().unwrap(),
&expected.to_string(),
"regexp_substr scalar test failed"
);
} else {
assert_eq!(
"", expected,
"regexp_substr scalar utf8 test failed"
)
regex.iter().enumerate().for_each(|(spos, &regex)| {
values.iter().enumerate().for_each(|(pos, &value)| {
let expected = expected.get(spos).unwrap().get(pos).cloned().unwrap();
// Utf8, LargeUtf8
for (data_type, scalar) in &[
(
DataType::Utf8,
ScalarValue::Utf8 as fn(Option<String>) -> ScalarValue,
),
(
DataType::LargeUtf8,
ScalarValue::LargeUtf8 as fn(Option<String>) -> ScalarValue,
),
] {
let result =
RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(scalar(Some(value.to_string()))),
ColumnarValue::Scalar(scalar(Some(regex.to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(scalar(Some(
flags[spos].to_string(),
))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(
group_num[spos],
))),
],
number_rows: 1,
return_type: data_type,
});
match result {
Ok(ColumnarValue::Scalar(
ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res),
)) => {
if res.is_some() {
assert_eq!(
res.as_ref().unwrap(),
&expected.to_string(),
"regexp_substr scalar test failed"
);
} else {
assert_eq!(
"", expected,
"regexp_substr scalar utf8 test failed"
)
}
}
_ => panic!("Unexpected result"),
}
_ => panic!("Unexpected result"),
}
}
})
});
}

Expand Down
Loading