Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,32 @@ async fn test_first_val() {
.await;
}

/// Fuzz test for the `last_value` aggregate: every column of the baseline config
/// is offered both as an aggregate argument and as a group by column.
#[tokio::test(flavor = "multi_thread")]
async fn test_last_val() {
    let mut data_gen_config = baseline_config();

    // Minimize the chance of identical values in the order by columns to make the
    // test more stable: any column that reports no max distinct count is rebuilt
    // with `usize::MAX`.
    for column in data_gen_config.columns.iter_mut() {
        if column.get_max_num_distinct().is_none() {
            *column = column.clone().with_max_num_distinct(usize::MAX);
        }
    }

    let aggregate_arguments = data_gen_config.all_columns();
    let group_by_columns = data_gen_config.all_columns();

    let query_builder = QueryBuilder::new()
        .with_table_name("fuzz_table")
        .with_aggregate_function("last_value")
        .with_aggregate_arguments(aggregate_arguments)
        .set_group_by_columns(group_by_columns);

    AggregationFuzzerBuilder::from(data_gen_config)
        .add_query_builder(query_builder)
        .build()
        .run()
        .await;
}

#[tokio::test(flavor = "multi_thread")]
async fn test_max() {
let data_gen_config = baseline_config();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,9 @@ impl QueryBuilder {
let distinct = if *is_distinct { "DISTINCT " } else { "" };
alias_gen += 1;

let (order_by, null_opt) = if function_name.eq("first_value") {
let (order_by, null_opt) = if function_name.eq("first_value")
|| function_name.eq("last_value")
{
(
self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
self.null_opt(),
Expand Down
231 changes: 216 additions & 15 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue {
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
// TODO: extract to function
use DataType::*;
matches!(
args.return_type,
Expand Down Expand Up @@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue {
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// TODO: extract to function
fn create_accumulator<T>(
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>>
Expand All @@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue {
args.ignore_nulls,
args.return_type,
&ordering_dtypes,
true,
)?))
}

Expand Down Expand Up @@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue {
create_accumulator::<Time64NanosecondType>(args)
}

_ => internal_err!(
"GroupsAccumulator not supported for first({})",
args.return_type
),
_ => {
internal_err!(
"GroupsAccumulator not supported for first_value({})",
args.return_type
)
}
}
}

Expand Down Expand Up @@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue {
}
}

// TODO: rename to PrimitiveGroupsAccumulator
Copy link
Contributor

@jayzhan211 jayzhan211 Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to include this kind of improvement in the same PR to avoid confusion. Only changes that are highly independent should be considered for splitting into smaller PRs.

struct FirstPrimitiveGroupsAccumulator<T>
where
T: ArrowPrimitiveType + Send,
Expand All @@ -316,12 +322,16 @@ where
// buffer for `get_filtered_min_of_each_group`
// min_of_each_group_buf.0[group_idx] -> idx_in_val
// only valid if min_of_each_group_buf.1[group_idx] == true
// TODO: rename to extreme_of_each_group_buf
min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

// =========== option ============

// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
// true: take first element in an aggregation group according to the requested ordering.
// false: take last element in an aggregation group according to the requested ordering.
pick_first_in_group: bool,
// derived from `ordering_req`.
sort_options: Vec<SortOptions>,
// Stores whether incoming data already satisfies the ordering requirement.
Expand All @@ -342,6 +352,7 @@ where
ignore_nulls: bool,
data_type: &DataType,
ordering_dtypes: &[DataType],
pick_first_in_group: bool,
) -> Result<Self> {
let requirement_satisfied = ordering_req.is_empty();

Expand All @@ -365,6 +376,7 @@ where
is_sets: BooleanBufferBuilder::new(0),
size_of_orderings: 0,
min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
pick_first_in_group,
})
}

Expand All @@ -391,8 +403,13 @@ where

assert!(new_ordering_values.len() == self.ordering_req.len());
let current_ordering = &self.orderings[group_idx];
compare_rows(current_ordering, new_ordering_values, &self.sort_options)
.map(|x| x.is_gt())
compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
if self.pick_first_in_group {
x.is_gt()
} else {
x.is_lt()
}
})
}

fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
Expand Down Expand Up @@ -501,10 +518,10 @@ where
.map(ScalarValue::size_of_vec)
.sum::<usize>()
}

/// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
/// extreme value in `orderings` for each group, using lexicographical comparison:
/// the minimum when `pick_first_in_group` is true, the maximum otherwise.
/// Values are filtered using `opt_filter` and `is_set_arr` if provided.
/// TODO: rename to get_filtered_extreme_of_each_group
fn get_filtered_min_of_each_group(
&mut self,
orderings: &[ArrayRef],
Expand Down Expand Up @@ -556,15 +573,19 @@ where
}

let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
if is_valid
&& comparator
.compare(self.min_of_each_group_buf.0[group_idx], idx_in_val)
.is_gt()
{
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
} else if !is_valid {

if !is_valid {
self.min_of_each_group_buf.1.set_bit(group_idx, true);
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
} else {
let ordering = comparator
.compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);

if (ordering.is_gt() && self.pick_first_in_group)
|| (ordering.is_lt() && !self.pick_first_in_group)
{
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
}
}
}

Expand Down Expand Up @@ -1052,6 +1073,109 @@ impl AggregateUDFImpl for LastValue {
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
use DataType::*;
matches!(
args.return_type,
Int8 | Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
| Decimal256(_, _)
| Date32
| Date64
| Time32(_)
| Time64(_)
| Timestamp(_, _)
)
}

fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
fn create_accumulator<T>(
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>>
where
T: ArrowPrimitiveType + Send,
{
let ordering_dtypes = args
.ordering_req
.iter()
.map(|e| e.expr.data_type(args.schema))
.collect::<Result<Vec<_>>>()?;

Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
args.ordering_req.clone(),
args.ignore_nulls,
args.return_type,
&ordering_dtypes,
false,
)?))
}

match args.return_type {
DataType::Int8 => create_accumulator::<Int8Type>(args),
DataType::Int16 => create_accumulator::<Int16Type>(args),
DataType::Int32 => create_accumulator::<Int32Type>(args),
DataType::Int64 => create_accumulator::<Int64Type>(args),
DataType::UInt8 => create_accumulator::<UInt8Type>(args),
DataType::UInt16 => create_accumulator::<UInt16Type>(args),
DataType::UInt32 => create_accumulator::<UInt32Type>(args),
DataType::UInt64 => create_accumulator::<UInt64Type>(args),
DataType::Float16 => create_accumulator::<Float16Type>(args),
DataType::Float32 => create_accumulator::<Float32Type>(args),
DataType::Float64 => create_accumulator::<Float64Type>(args),

DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),

DataType::Timestamp(TimeUnit::Second, _) => {
create_accumulator::<TimestampSecondType>(args)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
create_accumulator::<TimestampMillisecondType>(args)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
create_accumulator::<TimestampMicrosecondType>(args)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
create_accumulator::<TimestampNanosecondType>(args)
}

DataType::Date32 => create_accumulator::<Date32Type>(args),
DataType::Date64 => create_accumulator::<Date64Type>(args),
DataType::Time32(TimeUnit::Second) => {
create_accumulator::<Time32SecondType>(args)
}
DataType::Time32(TimeUnit::Millisecond) => {
create_accumulator::<Time32MillisecondType>(args)
}

DataType::Time64(TimeUnit::Microsecond) => {
create_accumulator::<Time64MicrosecondType>(args)
}
DataType::Time64(TimeUnit::Nanosecond) => {
create_accumulator::<Time64NanosecondType>(args)
}

_ => {
internal_err!(
"GroupsAccumulator not supported for last_value({})",
args.return_type
)
}
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -1411,6 +1535,7 @@ mod tests {
true,
&DataType::Int64,
&[DataType::Int64],
true,
)?;

let mut val_with_orderings = {
Expand Down Expand Up @@ -1485,7 +1610,7 @@ mod tests {
}

#[test]
fn test_frist_group_acc_size_of_ordering() -> Result<()> {
fn test_group_acc_size_of_ordering() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Int64, true),
Expand All @@ -1504,6 +1629,7 @@ mod tests {
true,
&DataType::Int64,
&[DataType::Int64],
true,
)?;

let val_with_orderings = {
Expand Down Expand Up @@ -1563,4 +1689,79 @@ mod tests {

Ok(())
}

/// Exercises the shared first/last groups accumulator in "last" mode
/// (`pick_first_in_group = false`) through update, state, merge and evaluate.
#[test]
fn test_last_group_acc() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Int64, true),
        Field::new("d", DataType::Int32, true),
        Field::new("e", DataType::Boolean, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    // Order by column "c" with the default sort options.
    let order_by_c = PhysicalSortExpr {
        expr: col("c", &schema).unwrap(),
        options: SortOptions::default(),
    };
    let sort_key = LexOrdering::new(vec![order_by_c]);

    let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
        sort_key,
        true,
        &DataType::Int64,
        &[DataType::Int64],
        false,
    )?;

    // First batch: values followed by their ordering column.
    let first_values: ArrayRef =
        Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
    let first_orderings: ArrayRef = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
    let first_batch = vec![first_values, first_orderings];

    // The filter masks out the third row, the only one assigned to group 2.
    let update_filter = BooleanArray::from(vec![true, true, false, true]);
    group_acc.update_batch(&first_batch, &[0, 1, 2, 1], Some(&update_filter), 3)?;

    let state = group_acc.state(EmitTo::All)?;

    // State layout checked here: values, ordering values, then the is-set flags.
    let expected_state: Vec<Arc<dyn Array>> = vec![
        Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
        Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
        Arc::new(BooleanArray::from(vec![true, true, false])),
    ];
    assert_eq!(state, expected_state);

    // Merge the emitted state back, keeping only the row for group 0.
    let merge_filter = BooleanArray::from(vec![true, false, false]);
    group_acc.merge_batch(&state, &[0, 1, 2], Some(&merge_filter), 3)?;

    // Second batch targets groups 1 and 2 and grows the group count to 4.
    let second_values: ArrayRef = Arc::new(Int64Array::from(vec![66, 6]));
    let second_orderings: ArrayRef = Arc::new(Int64Array::from(vec![66, 6]));
    let second_batch = vec![second_values, second_orderings];

    group_acc.update_batch(&second_batch, &[1, 2], None, 4)?;

    let evaluated = group_acc.evaluate(EmitTo::All)?;
    let actual = evaluated.as_any().downcast_ref::<Int64Array>().unwrap();

    let expected: PrimitiveArray<Int64Type> =
        Int64Array::from(vec![Some(1), Some(66), Some(6), None]);

    assert_eq!(actual, &expected);

    Ok(())
}
}
Loading