@@ -615,14 +615,25 @@ mod tests {
615615 /// Test three-way join: customer -> orders -> lineitem
616616 #[ test]
617617 fn test_three_way_join_customer_orders_lineitem ( ) -> Result < ( ) > {
618+ use datafusion_expr:: test:: function_stub:: sum;
619+ use datafusion_expr:: { col, in_subquery, lit} ;
618620 // Create the base table scans with statistics
619- let customer = scan_tpch_table_with_stats ( "customer" , 150 ) ;
620- let orders = scan_tpch_table_with_stats ( "orders" , 1500 ) ;
621- let lineitem = scan_tpch_table_with_stats ( "lineitem" , 6000 ) ;
621+ // Create the base table scans with statistics
622+ let customer = scan_tpch_table_with_stats ( "customer" , 150_000 ) ;
623+ let orders = scan_tpch_table_with_stats ( "orders" , 1_500_000 ) ;
624+ let lineitem = scan_tpch_table_with_stats ( "lineitem" , 6_000_000 ) ;
625+
626+ // Step 1: Build the subquery
627+ // SELECT l_orderkey FROM lineitem
628+ // GROUP BY l_orderkey
629+ // HAVING sum(l_quantity) > 300
630+ let subquery = LogicalPlanBuilder :: from ( lineitem. clone ( ) )
631+ . aggregate ( vec ! [ col( "l_orderkey" ) ] , vec ! [ sum( col( "l_quantity" ) ) ] ) ?
632+ . filter ( sum ( col ( "l_quantity" ) ) . gt ( lit ( 300 ) ) ) ?
633+ . project ( vec ! [ col( "l_orderkey" ) ] ) ?
634+ . build ( ) ?;
622635
623- // Build a join plan: customer JOIN orders JOIN lineitem
624- // customer.c_custkey = orders.o_custkey
625- // orders.o_orderkey = lineitem.l_orderkey
636+ // Step 2: Build the main query with joins
626637 let plan = LogicalPlanBuilder :: from ( customer. clone ( ) )
627638 . join (
628639 orders. clone ( ) ,
@@ -636,6 +647,22 @@ mod tests {
636647 ( vec ! [ "o_orderkey" ] , vec ! [ "l_orderkey" ] ) ,
637648 None ,
638649 ) ?
650+ // Step 3: Apply the IN subquery filter
651+ . filter ( in_subquery ( col ( "o_orderkey" ) , Arc :: new ( subquery) ) ) ?
652+ // Step 4: Aggregate
653+ . aggregate (
654+ vec ! [
655+ col( "c_name" ) ,
656+ col( "c_custkey" ) ,
657+ col( "o_orderkey" ) ,
658+ col( "o_totalprice" ) ,
659+ ] ,
660+ vec ! [ sum( col( "l_quantity" ) ) ] ,
661+ ) ?
662+ // Step 5: Sort
663+ . sort ( vec ! [ col( "o_totalprice" ) . sort( false , true ) ] ) ?
664+ // Step 6: Limit
665+ . limit ( 0 , Some ( 100 ) ) ?
639666 . build ( ) ?;
640667
641668 let query_graph = QueryGraph :: try_from ( plan) . unwrap ( ) ;
0 commit comments