Skip to content
Draft
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ macro_rules! arena_types {
[decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph,
[] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls,
[] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>,
[] thir_pats: rustc_middle::thir::Pat<'tcx>,
]);
)
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_mir_build/src/builder/expr/as_place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ impl<'tcx> PlaceBuilder<'tcx> {
&self.projection
}

pub(crate) fn projection_mut(&mut self) -> &mut [PlaceElem<'tcx>] {
&mut self.projection
}

pub(crate) fn field(self, f: FieldIdx, ty: Ty<'tcx>) -> Self {
self.project(PlaceElem::Field(f, ty))
}
Expand Down
295 changes: 279 additions & 16 deletions compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::ops;

use either::Either;
use rustc_middle::bug;
use rustc_middle::mir::*;
use rustc_middle::thir::{self, *};
use rustc_middle::ty::{self, Ty, TypeVisitableExt};

use crate::builder::Builder;
use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder};
use crate::builder::matches::util::Range;
use crate::builder::matches::{FlatPat, MatchPairTree, TestCase};

impl<'a, 'tcx> Builder<'a, 'tcx> {
Expand Down Expand Up @@ -33,6 +38,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Used internally by [`MatchPairTree::for_pattern`].
fn prefix_slice_suffix<'pat>(
&mut self,
top_pattern: &'pat Pat<'tcx>,
match_pairs: &mut Vec<MatchPairTree<'pat, 'tcx>>,
place: &PlaceBuilder<'tcx>,
prefix: &'pat [Box<Pat<'tcx>>],
Expand All @@ -54,11 +60,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
((prefix.len() + suffix.len()).try_into().unwrap(), false)
};

match_pairs.extend(prefix.iter().enumerate().map(|(idx, subpattern)| {
let elem =
ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false };
MatchPairTree::for_pattern(place.clone_project(elem), subpattern, self)
}));
if !prefix.is_empty() {
let bounds = Range::from_start(0..prefix.len() as u64);
let subpattern = bounds.apply(prefix);
self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length)
.for_each(|pair| match_pairs.push(pair));
}

if let Some(subslice_pat) = opt_slice {
let suffix_len = suffix.len() as u64;
Expand All @@ -70,16 +77,258 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
match_pairs.push(MatchPairTree::for_pattern(subslice, subslice_pat, self));
}

match_pairs.extend(suffix.iter().rev().enumerate().map(|(idx, subpattern)| {
let end_offset = (idx + 1) as u64;
let elem = ProjectionElem::ConstantIndex {
offset: if exact_size { min_length - end_offset } else { end_offset },
min_length,
from_end: !exact_size,
if !suffix.is_empty() {
let bounds = Range::from_end(0..suffix.len() as u64);
let subpattern = bounds.apply(suffix);
self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length)
.for_each(|pair| match_pairs.push(pair));
}
}

// Traverses either side of a slice pattern (prefix/suffix) and yields an iterator of `MatchPairTree`s
// to cover all it's constant and non-constant subpatterns.
fn build_slice_branch<'pat, 'b>(
&'b mut self,
bounds: Range,
place: &'b PlaceBuilder<'tcx>,
top_pattern: &'pat Pat<'tcx>,
pattern: &'pat [Box<Pat<'tcx>>],
min_length: u64,
) -> impl Iterator<Item = MatchPairTree<'pat, 'tcx>> + use<'a, 'tcx, 'pat, 'b> {
let entries = self.find_const_groups(pattern);

entries.into_iter().map(move |entry| {
// Common case handler for both non-constant and constant subpatterns not in a range.
let mut build_single = |idx| {
let subpattern = &pattern[idx as usize];
let place = place.clone_project(ProjectionElem::ConstantIndex {
offset: bounds.shift_idx(idx),
min_length: pattern.len() as u64,
from_end: bounds.from_end,
});

MatchPairTree::for_pattern(place, subpattern, self)
};
let place = place.clone_project(elem);
MatchPairTree::for_pattern(place, subpattern, self)
}));

match entry {
Either::Right(range) if range.end - range.start > 1 => {
// Figure out which subslice of our already sliced pattern we're looking at.
let subpattern = &pattern[range.start as usize..range.end as usize];
let elem_ty = subpattern[0].ty;

// Right, we 've found a group of constant patterns worth grouping for later.
// We'll collect all the leaves we can find and create a single `ValTree` out of them.
let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern);
self.valtree_to_match_pair(
top_pattern,
valtree,
place.clone(),
elem_ty,
bounds.shift_range(range),
min_length,
)
}
Either::Right(range) => build_single(range.start),
Either::Left(idx) => build_single(idx),
}
})
}

// Given a partial view of the elements in a slice pattern, returns a list
// with left denoting non-constant element indices and right denoting ranges of constant elements.
fn find_const_groups(&self, pattern: &[Box<Pat<'tcx>>]) -> Vec<Either<u64, ops::Range<u64>>> {
let mut entries = Vec::new();
let mut current_seq_start = None;

for (idx, pat) in pattern.iter().enumerate() {
if self.is_constant_pattern(pat) {
if current_seq_start.is_none() {
current_seq_start = Some(idx as u64);
} else {
continue;
}
} else {
if let Some(start) = current_seq_start {
entries.push(Either::Right(start..idx as u64));
current_seq_start = None;
}
entries.push(Either::Left(idx as u64));
}
}

if let Some(start) = current_seq_start {
entries.push(Either::Right(start..pattern.len() as u64));
}

entries
}

// Checks if a pattern is constant and represented by a single scalar leaf.
fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool {
if let PatKind::Constant { value } = pat.kind
&& let Const::Ty(_, const_) = value
&& let ty::ConstKind::Value(cv) = const_.kind()
&& let ty::ValTree::Leaf(_) = cv.valtree
{
true
} else {
false
}
}

// Extract the `ValTree` from a constant pattern.
// You must ensure that the pattern is a constant pattern before calling this function or it will panic.
fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> {
if let PatKind::Constant { value } = pat.kind
&& let Const::Ty(_, const_) = value
&& let ty::ConstKind::Value(cv) = const_.kind()
&& matches!(cv.valtree, ty::ValTree::Leaf(_))
{
cv.valtree
} else {
bug!("expected constant pattern, got {:?}", pat)
}
}

// Simplifies a slice of constant patterns into a single flattened `ValTree`.
fn simplify_const_pattern_slice_into_valtree(
&self,
subslice: &[Box<Pat<'tcx>>],
) -> ty::ValTree<'tcx> {
let leaves = subslice.iter().map(|p| self.extract_leaf(p));
let interned = self.tcx.arena.alloc_from_iter(leaves);
ty::ValTree::Branch(interned)
}

// Given a `ValTree` representing a slice of constant patterns, returns a `MatchPairTree`
// representing the slice pattern, providing as much info about subsequences in the slice as possible
// to later lowering stages.
fn valtree_to_match_pair<'pat>(
&mut self,
source_pattern: &'pat Pat<'tcx>,
valtree: ty::ValTree<'tcx>,
place: PlaceBuilder<'tcx>,
elem_ty: Ty<'tcx>,
range: Range,
min_length: u64,
) -> MatchPairTree<'pat, 'tcx> {
let tcx = self.tcx;
let leaves = match valtree {
ty::ValTree::Leaf(_) => bug!("expected branch, got leaf"),
ty::ValTree::Branch(leaves) => leaves,
};

assert!(range.len() == leaves.len() as u64);
let mut subpairs = Vec::new();
let mut were_merged = 0;

if elem_ty == tcx.types.u8 {
let leaf_bits = |leaf: ty::ValTree<'tcx>| match leaf {
ty::ValTree::Leaf(scalar) => scalar.to_u8(),
_ => bug!("found unflatted valtree"),
};

let mut fuse_group = |first_idx, len| {
were_merged += len;

let data = leaves[first_idx..first_idx + len]
.iter()
.rev()
.copied()
.map(leaf_bits)
.fold(0u32, |acc, x| (acc << 8) | u32::from(x));

let fused_ty = match len {
2 => tcx.types.u16,
3 | 4 => tcx.types.u32,
_ => unreachable!(),
};

let scalar = match len {
2 => ty::ScalarInt::from(data as u16),
3 | 4 => ty::ScalarInt::from(data),
_ => unreachable!(),
};

let valtree = ty::ValTree::Leaf(scalar);
let ty_const =
ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree }));

let value = Const::Ty(fused_ty, ty_const);
let test_case = TestCase::FusedConstant { value, fused: len as u64 };

let pattern = tcx.arena.alloc(Pat {
ty: fused_ty,
span: source_pattern.span,
kind: PatKind::Constant { value },
});

let place = place
.clone_project(ProjectionElem::ConstantIndex {
offset: range.shift_idx(first_idx as u64),
min_length,
from_end: range.from_end,
})
.to_place(self);

subpairs.push(MatchPairTree {
place: Some(place),
test_case,
subpairs: Vec::new(),
pattern,
});
};

let indices = |group_size, skip| {
(skip..usize::MAX)
.take_while(move |i| i * group_size + (group_size - 1) < leaves.len())
};

let mut skip = 0;
for i in (2..=4).rev() {
for idx in indices(i, skip) {
fuse_group(idx * i, i);
skip += i;
}
}
}

for (idx, leaf) in leaves.iter().enumerate().skip(were_merged) {
let ty_const = ty::Const::new(
tcx,
ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree: *leaf }),
);
let value = Const::Ty(elem_ty, ty_const);
let test_case = TestCase::Constant { value };

let pattern = tcx.arena.alloc(Pat {
ty: elem_ty,
span: source_pattern.span,
kind: PatKind::Constant { value },
});

let place = place
.clone_project(ProjectionElem::ConstantIndex {
offset: range.start + idx as u64,
min_length,
from_end: range.from_end,
})
.to_place(self);

subpairs.push(MatchPairTree {
place: Some(place),
test_case,
subpairs: Vec::new(),
pattern,
});
}

MatchPairTree {
place: None,
test_case: TestCase::Irrefutable { binding: None, ascription: None },
subpairs,
pattern: source_pattern,
}
}
}

Expand Down Expand Up @@ -192,11 +441,25 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> {
}

PatKind::Array { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix);
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
&place_builder,
prefix,
slice,
suffix,
);
default_irrefutable()
}
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix);
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
&place_builder,
prefix,
slice,
suffix,
);

if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
default_irrefutable()
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_mir_build/src/builder/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ enum TestCase<'pat, 'tcx> {
Irrefutable { binding: Option<Binding<'tcx>>, ascription: Option<Ascription<'tcx>> },
Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx },
Constant { value: mir::Const<'tcx> },
FusedConstant { value: mir::Const<'tcx>, fused: u64 },
Range(&'pat PatRange<'tcx>),
Slice { len: usize, variable_length: bool },
Deref { temp: Place<'tcx>, mutability: Mutability },
Expand Down Expand Up @@ -1304,7 +1305,7 @@ enum TestKind<'tcx> {
///
/// The test's target values are not stored here; instead they are extracted
/// from the [`TestCase`]s of the candidates participating in the test.
SwitchInt,
SwitchInt { fused: u64 },

/// Test whether a `bool` is `true` or `false`.
If,
Expand Down
Loading
Loading