diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index dcf477135a37..2bceff979c86 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -113,6 +113,32 @@ async fn test_first_val() { .await; } +#[tokio::test(flavor = "multi_thread")] +async fn test_last_val() { + let mut data_gen_config = baseline_config(); + + for i in 0..data_gen_config.columns.len() { + if data_gen_config.columns[i].get_max_num_distinct().is_none() { + data_gen_config.columns[i] = data_gen_config.columns[i] + .clone() + // Minimize the chance of identical values in the order by columns to make the test more stable + .with_max_num_distinct(usize::MAX); + } + } + + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("last_value") + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + #[tokio::test(flavor = "multi_thread")] async fn test_max() { let data_gen_config = baseline_config(); diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index bb24fb554d65..c37417bd43c0 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -503,7 +503,9 @@ impl QueryBuilder { let distinct = if *is_distinct { "DISTINCT " } else { "" }; alias_gen += 1; - let (order_by, null_opt) = if function_name.eq("first_value") { + let (order_by, null_opt) = if function_name.eq("first_value") + || function_name.eq("last_value") + { ( self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ self.null_opt(), diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 28e6a8723dfd..646543637535 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + // TODO: extract to function use DataType::*; matches!( args.return_type, @@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue { &self, args: AccumulatorArgs, ) -> Result> { + // TODO: extract to function fn create_accumulator( args: AccumulatorArgs, ) -> Result> @@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue { args.ignore_nulls, args.return_type, &ordering_dtypes, + true, )?)) } @@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue { create_accumulator::(args) } - _ => internal_err!( - "GroupsAccumulator not supported for first({})", - args.return_type - ), + _ => { + internal_err!( + "GroupsAccumulator not supported for first_value({})", + args.return_type + ) + } } } @@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue { } } +// TODO: rename to PrimitiveGroupsAccumulator struct FirstPrimitiveGroupsAccumulator where T: ArrowPrimitiveType + Send, @@ -316,12 +322,16 @@ where // buffer for `get_filtered_min_of_each_group` // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val // only valid if filter_min_of_each_group_buf.1[group_idx] == true + // TODO: rename to extreme_of_each_group_buf min_of_each_group_buf: (Vec, BooleanBufferBuilder), // =========== option ============ // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // true: take first element in an aggregation group according to the requested ordering. + // false: take last element in an aggregation group according to the requested ordering. + pick_first_in_group: bool, // derived from `ordering_req`. sort_options: Vec, // Stores whether incoming data already satisfies the ordering requirement. @@ -342,6 +352,7 @@ where ignore_nulls: bool, data_type: &DataType, ordering_dtypes: &[DataType], + pick_first_in_group: bool, ) -> Result { let requirement_satisfied = ordering_req.is_empty(); @@ -365,6 +376,7 @@ where is_sets: BooleanBufferBuilder::new(0), size_of_orderings: 0, min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), + pick_first_in_group, }) } @@ -391,8 +403,13 @@ where assert!(new_ordering_values.len() == self.ordering_req.len()); let current_ordering = &self.orderings[group_idx]; - compare_rows(current_ordering, new_ordering_values, &self.sort_options) - .map(|x| x.is_gt()) + compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| { + if self.pick_first_in_group { + x.is_gt() + } else { + x.is_lt() + } + }) } fn take_orderings(&mut self, emit_to: EmitTo) -> Vec> { @@ -501,10 +518,10 @@ where .map(ScalarValue::size_of_vec) .sum::() } - /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the /// minimum value in `orderings` for each group, using lexicographical comparison. /// Values are filtered using `opt_filter` and `is_set_arr` if provided. + /// TODO: rename to get_filtered_extreme_of_each_group fn get_filtered_min_of_each_group( &mut self, orderings: &[ArrayRef], @@ -556,15 +573,19 @@ where } let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx); - if is_valid - && comparator - .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val) - .is_gt() - { - self.min_of_each_group_buf.0[group_idx] = idx_in_val; - } else if !is_valid { + + if !is_valid { self.min_of_each_group_buf.1.set_bit(group_idx, true); self.min_of_each_group_buf.0[group_idx] = idx_in_val; + } else { + let ordering = comparator + .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val); + + if (ordering.is_gt() && self.pick_first_in_group) + || (ordering.is_lt() && !self.pick_first_in_group) + { + self.min_of_each_group_buf.0[group_idx] = idx_in_val; + } } } @@ -1052,6 +1073,109 @@ impl AggregateUDFImpl for LastValue { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + use DataType::*; + matches!( + args.return_type, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + fn create_accumulator( + args: AccumulatorArgs, + ) -> Result> + where + T: ArrowPrimitiveType + Send, + { + let ordering_dtypes = args + .ordering_req + .iter() + .map(|e| e.expr.data_type(args.schema)) + .collect::>>()?; + + Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( + args.ordering_req.clone(), + args.ignore_nulls, + args.return_type, + &ordering_dtypes, + false, + )?)) + } + + match args.return_type { + DataType::Int8 => create_accumulator::(args), + DataType::Int16 => create_accumulator::(args), + DataType::Int32 => create_accumulator::(args), + DataType::Int64 => create_accumulator::(args), + DataType::UInt8 => create_accumulator::(args), + DataType::UInt16 => create_accumulator::(args), + DataType::UInt32 => create_accumulator::(args), + DataType::UInt64 => create_accumulator::(args), + DataType::Float16 => create_accumulator::(args), + DataType::Float32 => create_accumulator::(args), + DataType::Float64 => create_accumulator::(args), + + DataType::Decimal128(_, _) => create_accumulator::(args), + DataType::Decimal256(_, _) => create_accumulator::(args), + + DataType::Timestamp(TimeUnit::Second, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + create_accumulator::(args) + } + + DataType::Date32 => create_accumulator::(args), + DataType::Date64 => create_accumulator::(args), + DataType::Time32(TimeUnit::Second) => { + create_accumulator::(args) + } + DataType::Time32(TimeUnit::Millisecond) => { + create_accumulator::(args) + } + + DataType::Time64(TimeUnit::Microsecond) => { + create_accumulator::(args) + } + DataType::Time64(TimeUnit::Nanosecond) => { + create_accumulator::(args) + } + + _ => { + internal_err!( + "GroupsAccumulator not supported for last_value({})", + args.return_type + ) + } + } + } } #[derive(Debug)] @@ -1411,6 +1535,7 @@ mod tests { true, &DataType::Int64, &[DataType::Int64], + true, )?; let mut val_with_orderings = { @@ -1485,7 +1610,7 @@ mod tests { } #[test] - fn test_frist_group_acc_size_of_ordering() -> Result<()> { + fn test_group_acc_size_of_ordering() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1504,6 +1629,7 @@ mod tests { true, &DataType::Int64, &[DataType::Int64], + true, )?; let val_with_orderings = { @@ -1563,4 +1689,79 @@ mod tests { Ok(()) } + + #[test] + fn test_last_group_acc() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])); + + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]); + + let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + sort_key, + true, + &DataType::Int64, + &[DataType::Int64], + false, + )?; + + let mut val_with_orderings = { + let mut val_with_orderings = Vec::::new(); + + let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)])); + let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6])); + + val_with_orderings.push(vals); + val_with_orderings.push(orderings); + + val_with_orderings + }; + + group_acc.update_batch( + &val_with_orderings, + &[0, 1, 2, 1], + Some(&BooleanArray::from(vec![true, true, false, true])), + 3, + )?; + + let state = group_acc.state(EmitTo::All)?; + + let expected_state: Vec> = vec![ + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(BooleanArray::from(vec![true, true, false])), + ]; + assert_eq!(state, expected_state); + + group_acc.merge_batch( + &state, + &[0, 1, 2], + Some(&BooleanArray::from(vec![true, false, false])), + 3, + )?; + + val_with_orderings.clear(); + val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6]))); + val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6]))); + + group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?; + + let binding = group_acc.evaluate(EmitTo::All)?; + let eval_result = binding.as_any().downcast_ref::().unwrap(); + + let expect: PrimitiveArray = + Int64Array::from(vec![Some(1), Some(66), Some(6), None]); + + assert_eq!(eval_result, &expect); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 4c4999a364d1..9e67018ecd0b 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -2232,7 +2232,7 @@ physical_plan 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c) as last_c +SELECT a, b, LAST_VALUE(c order by c) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2706,6 +2706,29 @@ select k, first_value(val order by o) respect NULLS from first_null group by k; 1 1 +statement ok +CREATE TABLE last_null ( + k INT, + val INT, + o int + ) as VALUES + (0, NULL, 9), + (0, 1, 1), + (1, 1, 1); + +query II rowsort +select k, last_value(val order by o) IGNORE NULLS from last_null group by k; +---- +0 1 +1 1 + +query II rowsort +select k, last_value(val order by o) respect NULLS from last_null group by k; +---- +0 NULL +1 1 + + query TT EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, @@ -3775,7 +3798,7 @@ ORDER BY x; 2 2 query II -SELECT y, LAST_VALUE(x) +SELECT y, LAST_VALUE(x order by x desc) FROM FOO GROUP BY y ORDER BY y;