From b26d272dbe7e1a7dc5dcf4829b8f99b4ecaf2c16 Mon Sep 17 00:00:00 2001 From: AppCoder1234 Date: Wed, 16 Feb 2022 15:05:19 +0100 Subject: [PATCH 1/4] Code assist convert_if_to_filter Detect simple cases where if can be transformed into filter --- .../src/handlers/lambdify_for_each.rs | 162 ++++++++++++++++++ crates/ide_assists/src/lib.rs | 2 + crates/ide_assists/src/tests/generated.rs | 28 +++ 3 files changed, 192 insertions(+) create mode 100644 crates/ide_assists/src/handlers/lambdify_for_each.rs diff --git a/crates/ide_assists/src/handlers/lambdify_for_each.rs b/crates/ide_assists/src/handlers/lambdify_for_each.rs new file mode 100644 index 000000000000..c5f9b92bce6d --- /dev/null +++ b/crates/ide_assists/src/handlers/lambdify_for_each.rs @@ -0,0 +1,162 @@ +use ide_db::helpers::FamousDefs; +use stdx::format_to; +use syntax::{ + SyntaxKind, + ast::{self, edit_in_place::Indent, HasArgList, Pat, Expr}, + AstNode, +}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: convert_if_to_filter +// +// Converts a if into a filter when placed in a for_each(). +// ``` +// # //- minicore: iterators +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// it.for_each$0(|x| { +// if x > 4 { +// println!("{}", x); +// }; +// }); +// } +// ``` +// -> +// ``` +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// it.filter(|&x| x > 4).for_each(|x| { +// println!("{}", x); +// }); +// } +// ``` +pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let method = ctx.find_node_at_offset::()?; + + let closure = match method.arg_list()?.args().next()? { + ast::Expr::ClosureExpr(expr) => expr, + _ => return None, + }; + + let (method, receiver) = validate_method_call_expr(ctx, method)?; + + let param_list = closure.param_list()?; + let param = param_list.params().next()?.pat()?; + let body = closure.body()?; + + let range = method.syntax().text_range(); + + let if_expr = match body.clone() { + Expr::IfExpr(if_expr) => { + if_expr + }, + Expr::BlockExpr(block) => { + let mut stmts = block.statements(); + let fst_stmt = stmts.next()?; + continue_iff(stmts.next().is_none())?; // Only one statement + // First statement is an expression... + let expr_stmt = match fst_stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt, + _ => return None, + }; + + // ...and even an if clause... + let expr = expr_stmt.expr()?; + let if_expr = match expr { + ast::Expr::IfExpr(my_if_expr) => my_if_expr, + _ => return None, + }; + if_expr + }, + _ => return None, + }; + + let condition = if_expr.condition()?; // ... with a condition... + continue_iff(if_expr.else_branch().is_none()); // ... and no else branch... + let then_branch = if_expr.then_branch()?; // ... and a then branch + // ... and the pattern in the for loop is an ident pattern + let ident_pat = match param { + Pat::IdentPat(ref ident_pat) => ident_pat.clone(), + _ => return None, + }; + + acc.add( + AssistId("convert_if_to_filter", AssistKind::RefactorRewrite), + "Replace this `if { ... }` with a `filter()`", + range, + |builder| { + let indent = method.indent_level(); + + let mut buf = String::new(); + // Remove unnecessary `mut` in the pattern if used in a filter + let pat_filter = ident_pat.clone_for_update(); + if let Some(mut_token) = pat_filter.mut_token() { + if let Some(ws) = mut_token.next_token().filter(|it| it.kind() == SyntaxKind::WHITESPACE) { + ws.detach(); + } + mut_token.detach(); + } + format_to!(buf, "{}.filter(|&{}| {})", receiver, pat_filter, condition); + + // Because we removed a if block, reident accordingly the rest of the block + let block = then_branch.clone_for_update(); + block.reindent_to(indent); + + format_to!(buf, ".for_each(|{}| {})", param, block); + + builder.replace(range, buf) + }, + ) +} + +fn validate_method_call_expr( + ctx: &AssistContext, + expr: ast::MethodCallExpr, +) -> Option<(ast::Expr, ast::Expr)> { + let name_ref = expr.name_ref()?; + if name_ref.text() != "for_each" { + return None; + } + + let receiver = expr.receiver()?; + let expr = ast::Expr::MethodCallExpr(expr); + + Some((expr, receiver)) +} + +fn continue_iff(b: bool) -> Option<()> { + if b { Some(()) } else { None } +} + +#[cfg(test)] +mod tests { + use crate::tests::check_assist; + + use super::*; + + #[test] + fn if_to_filter() { + check_assist( + convert_if_to_filter, + r#" +fn main() { + let it = core::iter::repeat(92); + it.for_each$0(|mut i| { + if (i*i)%3 == 2 { + i *= 2; + }; + }); +}"#, + r#" +fn main() { + let it = core::iter::repeat(92); + it.filter(|&i| (i*i)%3 == 2).for_each(|mut i| { + i *= 2; + }); +}"#, + ) + } +} diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs index 067f4d8e14d0..b16e48d92b00 100644 --- a/crates/ide_assists/src/lib.rs +++ b/crates/ide_assists/src/lib.rs @@ -151,6 +151,7 @@ mod handlers { mod inline_call; mod inline_local_variable; mod introduce_named_lifetime; + mod lambdify_for_each; mod invert_if; mod merge_imports; mod merge_match_arms; @@ -234,6 +235,7 @@ mod handlers { introduce_named_generic::introduce_named_generic, introduce_named_lifetime::introduce_named_lifetime, invert_if::invert_if, + lambdify_for_each::convert_if_to_filter, merge_imports::merge_imports, merge_match_arms::merge_match_arms, move_bounds::move_bounds_to_where_clause, diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs index 0ad4b3bc345c..48d7abaf7b61 100644 --- a/crates/ide_assists/src/tests/generated.rs +++ b/crates/ide_assists/src/tests/generated.rs @@ -297,6 +297,34 @@ fn main() { ) } +#[test] +fn doctest_convert_if_to_filter() { + check_doc_test( + "convert_if_to_filter", + r#####" +//- minicore: iterators +use core::iter; +fn main() { + let it = core::iter::repeat(92); + it.for_each$0(|x| { + if x > 4 { + println!("{}", x); + }; + }); +} +"#####, + r#####" +use core::iter; +fn main() { + let it = core::iter::repeat(92); + it.filter(|&x| x > 4).for_each(|x| { + println!("{}", x); + }); +} +"#####, + ) +} + #[test] fn doctest_convert_integer_literal() { check_doc_test( From f73296e7cff5fc604de79b14dcf5f7b2bb30f2ab Mon Sep 17 00:00:00 2001 From: AppCoder1234 Date: Wed, 16 Feb 2022 16:34:23 +0100 Subject: [PATCH 2/4] If to filter code assist Recursively remove mut idents in filter closure --- .../src/handlers/lambdify_for_each.rs | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/crates/ide_assists/src/handlers/lambdify_for_each.rs b/crates/ide_assists/src/handlers/lambdify_for_each.rs index c5f9b92bce6d..7dbff96b6547 100644 --- a/crates/ide_assists/src/handlers/lambdify_for_each.rs +++ b/crates/ide_assists/src/handlers/lambdify_for_each.rs @@ -1,10 +1,10 @@ -use ide_db::helpers::FamousDefs; use stdx::format_to; use syntax::{ SyntaxKind, ast::{self, edit_in_place::Indent, HasArgList, Pat, Expr}, AstNode, }; +use ide_db::helpers::node_ext::walk_pat; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -41,7 +41,7 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op _ => return None, }; - let (method, receiver) = validate_method_call_expr(ctx, method)?; + let (method, receiver) = validate_method_call_expr(method)?; let param_list = closure.param_list()?; let param = param_list.params().next()?.pat()?; @@ -77,11 +77,6 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op let condition = if_expr.condition()?; // ... with a condition... continue_iff(if_expr.else_branch().is_none()); // ... and no else branch... let then_branch = if_expr.then_branch()?; // ... and a then branch - // ... and the pattern in the for loop is an ident pattern - let ident_pat = match param { - Pat::IdentPat(ref ident_pat) => ident_pat.clone(), - _ => return None, - }; acc.add( AssistId("convert_if_to_filter", AssistKind::RefactorRewrite), @@ -91,9 +86,17 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op let indent = method.indent_level(); let mut buf = String::new(); - // Remove unnecessary `mut` in the pattern if used in a filter - let pat_filter = ident_pat.clone_for_update(); - if let Some(mut_token) = pat_filter.mut_token() { + // Recursively remove unnecessary `mut`s in the parameter + let pat_filter = param.clone_for_update(); + let mut to_be_removed = vec![]; + walk_pat(&pat_filter, &mut |cb| + if let Pat::IdentPat(ident) = cb { + if let Some(mut_token) = ident.mut_token() { + to_be_removed.push(mut_token); + } + } + ); + for mut_token in to_be_removed.into_iter() { if let Some(ws) = mut_token.next_token().filter(|it| it.kind() == SyntaxKind::WHITESPACE) { ws.detach(); } @@ -113,7 +116,6 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op } fn validate_method_call_expr( - ctx: &AssistContext, expr: ast::MethodCallExpr, ) -> Option<(ast::Expr, ast::Expr)> { let name_ref = expr.name_ref()?; @@ -143,17 +145,17 @@ mod tests { convert_if_to_filter, r#" fn main() { - let it = core::iter::repeat(92); - it.for_each$0(|mut i| { - if (i*i)%3 == 2 { + let it = core::iter::repeat((92,42)); + it.for_each$0(|(mut i,mut j)| { + if (i*j)%3 == 2 { i *= 2; }; }); }"#, r#" fn main() { - let it = core::iter::repeat(92); - it.filter(|&i| (i*i)%3 == 2).for_each(|mut i| { + let it = core::iter::repeat((92,42)); + it.filter(|&(i,j)| (i*j)%3 == 2).for_each(|(mut i,mut j)| { i *= 2; }); }"#, From 21733c080978d87d280276468927cdc0ff9d6ecb Mon Sep 17 00:00:00 2001 From: AppCoder1234 Date: Wed, 9 Mar 2022 15:28:08 +0100 Subject: [PATCH 3/4] If to filter code assist Apply code assist only when cursor is well positioned --- crates/ide_assists/src/handlers/lambdify_for_each.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crates/ide_assists/src/handlers/lambdify_for_each.rs b/crates/ide_assists/src/handlers/lambdify_for_each.rs index 7dbff96b6547..4cd066701acb 100644 --- a/crates/ide_assists/src/handlers/lambdify_for_each.rs +++ b/crates/ide_assists/src/handlers/lambdify_for_each.rs @@ -41,7 +41,7 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op _ => return None, }; - let (method, receiver) = validate_method_call_expr(method)?; + let (method, receiver) = validate_method_call_expr(ctx, method)?; let param_list = closure.param_list()?; let param = param_list.params().next()?.pat()?; @@ -116,16 +116,22 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op } fn validate_method_call_expr( + ctx: &AssistContext, expr: ast::MethodCallExpr, ) -> Option<(ast::Expr, ast::Expr)> { let name_ref = expr.name_ref()?; + if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) { + cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos); + return None; + } if name_ref.text() != "for_each" { return None; } + let receiver = expr.receiver()?; let expr = ast::Expr::MethodCallExpr(expr); - + Some((expr, receiver)) } From 9ee6ff4450ccc1f6a33bb934c29ca5430e3b0d6c Mon Sep 17 00:00:00 2001 From: AppCoder1234 Date: Wed, 20 Apr 2022 16:09:58 +0200 Subject: [PATCH 4/4] Detect and refactor unoptimized sum() and all() --- .../src/handlers/lambdify_for_each.rs | 224 +++++++++++++++++- crates/ide_assists/src/lib.rs | 2 + crates/ide_assists/src/tests/generated.rs | 48 ++++ 3 files changed, 271 insertions(+), 3 deletions(-) diff --git a/crates/ide_assists/src/handlers/lambdify_for_each.rs b/crates/ide_assists/src/handlers/lambdify_for_each.rs index 4cd066701acb..50c3ce91aec5 100644 --- a/crates/ide_assists/src/handlers/lambdify_for_each.rs +++ b/crates/ide_assists/src/handlers/lambdify_for_each.rs @@ -1,7 +1,8 @@ use stdx::format_to; +use hir::HirDisplay; use syntax::{ SyntaxKind, - ast::{self, edit_in_place::Indent, HasArgList, Pat, Expr}, + ast::{self, edit_in_place::Indent, HasArgList, Pat, Expr, BinaryOp, ArithOp}, AstNode, }; use ide_db::helpers::node_ext::walk_pat; @@ -115,6 +116,181 @@ pub(crate) fn convert_if_to_filter(acc: &mut Assists, ctx: &AssistContext) -> Op ) } +// Assist: convert_sum_call +// +// Converts a sum into a sum(). +// ``` +// # //- minicore: iterators +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// let mut val: usize = 0; +// it.for_each$0(|x| val += x); +// } +// ``` +// -> +// ``` +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// let mut val: usize = 0; +// val += it.sum::(); +// } +// ``` +pub(crate) fn convert_sum_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let method = ctx.find_node_at_offset::()?; + + let closure = match method.arg_list()?.args().next()? { + ast::Expr::ClosureExpr(expr) => expr, + _ => return None, + }; + + let (method, receiver) = validate_method_call_expr(ctx, method)?; + + let param_list = closure.param_list()?; + let param = param_list.params().next()?.pat()?; + let body = closure.body()?; + + let range = method.syntax().text_range(); + let module = ctx.sema.scope(param.syntax()).module()?; + + let binexpr = match body.clone() { + Expr::BinExpr(expr) => expr, + Expr::BlockExpr(block) => { + let mut stmts = block.statements(); + let fst_stmt = stmts.next()?; + continue_iff(stmts.next().is_none())?; // Only one statement + // First statement is an expression... + let expr_stmt = match fst_stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt, + _ => return None, + }; + + // ...and even a binary expr... + let expr = expr_stmt.expr()?; + let my_bin_expr = match expr { + ast::Expr::BinExpr(my_bin_expr) => my_bin_expr, + _ => return None, + }; + my_bin_expr + }, + _ => return None, + }; + let op = match binexpr.op_kind()? { + BinaryOp::Assignment { op } => op?, + _ => return None, + }; + match op { + ArithOp::Add => (), + _ => return None, + } + continue_iff(format!("{}", binexpr.rhs()?) == format!("{}", param))?; + let sum = binexpr.lhs()?; + + let ty = ctx.sema.type_of_pat(¶m)?.adjusted(); + + // Fully unresolved or unnameable types can't be annotated + if (ty.contains_unknown() && ty.type_arguments().count() == 0) || ty.is_closure() { + return None; + } + + let inferred_type = ty.display_source_code(ctx.db(), module.into()).ok()?; + + acc.add( + AssistId("convert_sum_call", AssistKind::RefactorRewrite), + "Replace this sum in disguise with a `sum()` call", + range, + |builder| { + let mut buf = String::new(); + format_to!(buf, "{} += {}.sum::<{}>()", sum, receiver, inferred_type); + builder.replace(range, buf) + }, + ) +} + +// Assist: convert_all_call +// +// Replace with an all() call when possible. +// ``` +// # //- minicore: iterators +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// let mut val: usize = 0; +// it.for_each$0(|x| val &= x > 0); +// } +// ``` +// -> +// ``` +// # use core::iter; +// fn main() { +// let it = core::iter::repeat(92); +// let mut val: usize = 0; +// val &= it.all(|&x| x > 0); +// } +// ``` +pub(crate) fn convert_all_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + let method = ctx.find_node_at_offset::()?; + + let closure = match method.arg_list()?.args().next()? { + ast::Expr::ClosureExpr(expr) => expr, + _ => return None, + }; + + let (method, receiver) = validate_method_call_expr(ctx, method)?; + + let param_list = closure.param_list()?; + let param = param_list.params().next()?.pat()?; + let body = closure.body()?; + + let range = method.syntax().text_range(); + + let binexpr = match body.clone() { + Expr::BinExpr(expr) => expr, + Expr::BlockExpr(block) => { + let mut stmts = block.statements(); + let fst_stmt = stmts.next()?; + continue_iff(stmts.next().is_none())?; // Only one statement + // First statement is an expression... + let expr_stmt = match fst_stmt { + ast::Stmt::ExprStmt(expr_stmt) => expr_stmt, + _ => return None, + }; + + // ...and even a binary expr... + let expr = expr_stmt.expr()?; + let my_bin_expr = match expr { + ast::Expr::BinExpr(my_bin_expr) => my_bin_expr, + _ => return None, + }; + my_bin_expr + }, + _ => return None, + }; + let op = match binexpr.op_kind()? { + BinaryOp::Assignment { op } => op?, + _ => return None, + }; + match op { + ArithOp::BitAnd => (), + _ => return None, + } + let rhs = binexpr.rhs()?; + let sum = binexpr.lhs()?; + + acc.add( + AssistId("convert_all_call", AssistKind::RefactorRewrite), + "Replace this with a `all()` call", + range, + |builder| { + let mut buf = String::new(); + format_to!(buf, "{} &= {}.all(|&{param}| {})", sum, receiver, rhs); + builder.replace(range, buf) + }, + ) +} + + fn validate_method_call_expr( ctx: &AssistContext, expr: ast::MethodCallExpr, @@ -152,7 +328,7 @@ mod tests { r#" fn main() { let it = core::iter::repeat((92,42)); - it.for_each$0(|(mut i,mut j)| { + it.for_each$0(|(mut i,j)| { if (i*j)%3 == 2 { i *= 2; }; @@ -161,9 +337,51 @@ fn main() { r#" fn main() { let it = core::iter::repeat((92,42)); - it.filter(|&(i,j)| (i*j)%3 == 2).for_each(|(mut i,mut j)| { + it.filter(|&(i,j)| (i*j)%3 == 2).for_each(|(mut i,j)| { i *= 2; }); +}"#, + ) + } + + #[test] + fn add_sum_call() { + check_assist( + convert_sum_call, + r#" +fn main() { + let it = core::iter::repeat(92); + let mut a: usize = 0; + it.for_each$0(|x| { + a += x; + }); +}"#, + r#" +fn main() { + let it = core::iter::repeat(92); + let mut a: usize = 0; + a += it.sum::(); +}"#, + ) + } + + #[test] + fn add_all_call() { + check_assist( + convert_all_call, + r#" +fn main() { + let it = core::iter::repeat(92); + let mut a = true; + it.for_each$0(|x| { + a &= x > 0; + }); +}"#, + r#" +fn main() { + let it = core::iter::repeat(92); + let mut a = true; + a &= it.all(|&x| x > 0); }"#, ) } diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs index b16e48d92b00..15976b1c4ea5 100644 --- a/crates/ide_assists/src/lib.rs +++ b/crates/ide_assists/src/lib.rs @@ -236,6 +236,8 @@ mod handlers { introduce_named_lifetime::introduce_named_lifetime, invert_if::invert_if, lambdify_for_each::convert_if_to_filter, + lambdify_for_each::convert_sum_call, + lambdify_for_each::convert_all_call, merge_imports::merge_imports, merge_match_arms::merge_match_arms, move_bounds::move_bounds_to_where_clause, diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs index 48d7abaf7b61..35c8789ccc37 100644 --- a/crates/ide_assists/src/tests/generated.rs +++ b/crates/ide_assists/src/tests/generated.rs @@ -230,6 +230,30 @@ pub(crate) fn frobnicate() {} ) } +#[test] +fn doctest_convert_all_call() { + check_doc_test( + "convert_all_call", + r#####" +//- minicore: iterators +use core::iter; +fn main() { + let it = core::iter::repeat(92); + let mut val: usize = 0; + it.for_each$0(|x| val &= x > 0); +} +"#####, + r#####" +use core::iter; +fn main() { + let it = core::iter::repeat(92); + let mut val: usize = 0; + val &= it.all(|&x| x > 0); +} +"#####, + ) +} + #[test] fn doctest_convert_bool_then_to_if() { check_doc_test( @@ -392,6 +416,30 @@ fn main() { ) } +#[test] +fn doctest_convert_sum_call() { + check_doc_test( + "convert_sum_call", + r#####" +//- minicore: iterators +use core::iter; +fn main() { + let it = core::iter::repeat(92); + let mut val: usize = 0; + it.for_each$0(|x| val += x); +} +"#####, + r#####" +use core::iter; +fn main() { + let it = core::iter::repeat(92); + let mut val: usize = 0; + val += it.sum::(); +} +"#####, + ) +} + #[test] fn doctest_convert_to_guarded_return() { check_doc_test(