Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

129 changes: 122 additions & 7 deletions compiler-core/checking/src/check/convert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter;

use itertools::Itertools;
use smol_str::SmolStr;

Expand Down Expand Up @@ -39,13 +41,13 @@ where
};
state.storage.intern(Type::Constructor(file_id, type_id))
}
lowering::TypeKind::Forall { bindings, type_ } => {
lowering::TypeKind::Forall { bindings, inner } => {
let binders = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect_vec();

let inner = type_.map_or(default, |id| type_to_core(state, context, id));
let inner = inner.map_or(default, |id| type_to_core(state, context, id));

let forall = binders
.into_iter()
Expand All @@ -65,7 +67,7 @@ where
lowering::TypeKind::String => default,
lowering::TypeKind::Variable { name, resolution } => {
let Some(resolution) = resolution else {
let name = name.clone().unwrap_or(INVALID_NAME);
let name = name.clone().unwrap_or(MISSING_NAME);
let kind = Variable::Free(name);
return state.storage.intern(Type::Variable(kind));
};
Expand All @@ -90,9 +92,122 @@ where
}
}

const INVALID_NAME: SmolStr = SmolStr::new_inline("<invalid>");
/// A variant of [`type_to_core`] for use with signature declarations.
///
/// Unlike the regular [`type_to_core`], this function does not call
/// [`CheckState::unbind`] after each [`lowering::TypeKind::Forall`]
/// node. This allows type variables to be scoped for the entire
/// declaration group rather than just the type signature.
pub fn signature_type_to_core<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
id: lowering::TypeId,
) -> TypeId
where
Q: ExternalQueries,
{
let default = context.prim.unknown;

let Some(kind) = context.lowered.info.get_type_kind(id) else {
return default;
};

match kind {
lowering::TypeKind::Forall { bindings, inner } => {
let binders = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect_vec();

let inner = inner.map_or(default, |id| type_to_core(state, context, id));

binders
.into_iter()
.rfold(inner, |inner, binder| state.storage.intern(Type::Forall(binder, inner)))
}

lowering::TypeKind::Parenthesized { parenthesized } => {
parenthesized.map(|id| signature_type_to_core(state, context, id)).unwrap_or(default)
}

_ => type_to_core(state, context, id),
}
}

pub struct InspectSignature {
pub variables: Vec<ForallBinder>,
pub arguments: Vec<TypeId>,
pub result: TypeId,
}

pub fn inspect_signature<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
id: lowering::TypeId,
) -> InspectSignature
where
Q: ExternalQueries,
{
let unknown = || {
let variables = [].into();
let arguments = [].into();
let result = context.prim.unknown;
InspectSignature { variables, arguments, result }
};

let Some(kind) = context.lowered.info.get_type_kind(id) else {
return unknown();
};

match kind {
lowering::TypeKind::Forall { bindings, inner } => {
let variables = bindings
.iter()
.map(|binding| convert_forall_binding(state, context, binding))
.collect();

let inner = inner.map_or(context.prim.unknown, |id| type_to_core(state, context, id));
let (arguments, result) = signature_components(state, inner);

InspectSignature { variables, arguments, result }
}

lowering::TypeKind::Parenthesized { parenthesized } => {
parenthesized.map(|id| inspect_signature(state, context, id)).unwrap_or_else(unknown)
}

_ => {
let variables = [].into();

let id = type_to_core(state, context, id);
let (arguments, result) = signature_components(state, id);

InspectSignature { variables, arguments, result }
}
}
}

fn signature_components(state: &mut CheckState, id: TypeId) -> (Vec<TypeId>, TypeId) {
let mut components = iter::successors(Some(id), |&id| match state.storage[id] {
Type::Function(_, id) => Some(id),
_ => None,
})
.map(|id| match state.storage[id] {
Type::Function(id, _) => id,
_ => id,
})
.collect_vec();

let Some(id) = components.pop() else {
unreachable!("invariant violated: expected non-empty components");
};

(components, id)
}

const MISSING_NAME: SmolStr = SmolStr::new_inline("<MissingName>");

fn convert_forall_binding<Q>(
pub fn convert_forall_binding<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
binding: &lowering::TypeVariableBinding,
Expand All @@ -101,11 +216,11 @@ where
Q: ExternalQueries,
{
let visible = binding.visible;
let name = binding.name.clone().unwrap_or(INVALID_NAME);
let name = binding.name.clone().unwrap_or(MISSING_NAME);

let kind = match binding.kind {
Some(id) => type_to_core(state, context, id),
None => state.fresh_unification(context),
None => state.fresh_unification_type(context),
};

let level = state.bind_forall(binding.id, kind);
Expand Down
4 changes: 2 additions & 2 deletions compiler-core/checking/src/check/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ where

Type::Variable(ref variable) => match variable {
Variable::Implicit(_) => context.prim.unknown,
Variable::Skolem(_) => context.prim.unknown,
Variable::Skolem(_, kind) => *kind,
Variable::Bound(index) => {
let size = state.bound.size();

Expand Down Expand Up @@ -254,6 +254,6 @@ where
Q: ExternalQueries,
{
let (inferred_type, inferred_kind) = infer_surface_kind(state, context, id);
unification::unify(state, context, inferred_kind, kind);
unification::subsumes(state, context, inferred_kind, kind);
(inferred_type, inferred_kind)
}
5 changes: 4 additions & 1 deletion compiler-core/checking/src/check/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ where
{
pub queries: &'a Q,
pub prim: PrimCore,

pub id: FileId,
pub indexed: Arc<IndexedModule>,
pub lowered: Arc<LoweredModule>,

pub prim_indexed: Arc<IndexedModule>,
}

Expand All @@ -116,7 +119,7 @@ where
let prim = PrimCore::collect(queries, state)?;
let prim_id = queries.prim_id();
let prim_indexed = queries.indexed(prim_id)?;
Ok(CheckContext { queries, prim, indexed, lowered, prim_indexed })
Ok(CheckContext { queries, prim, id, indexed, lowered, prim_indexed })
}
}

Expand Down
53 changes: 52 additions & 1 deletion compiler-core/checking/src/check/unification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,50 @@ pub mod level;
pub use context::*;

use crate::ExternalQueries;
use crate::check::{CheckContext, CheckState, kind};
use crate::check::{CheckContext, CheckState, kind, substitute};
use crate::core::{Type, TypeId, Variable, debruijn};

pub fn subsumes<Q>(
state: &mut CheckState,
context: &CheckContext<Q>,
t1: TypeId,
t2: TypeId,
) -> bool
where
Q: ExternalQueries,
{
let t1 = state.normalize_type(t1);
let t2 = state.normalize_type(t2);

let t1_core = state.storage[t1].clone();
let t2_core = state.storage[t2].clone();

match (t1_core, t2_core) {
(Type::Function(t1_argument, t1_result), Type::Function(t2_argument, t2_result)) => {
subsumes(state, context, t2_argument, t1_argument)
&& subsumes(state, context, t1_result, t2_result)
}

(_, Type::Forall(ref binder, inner)) => {
let v = Variable::Skolem(binder.level, binder.kind);
let t = state.storage.intern(Type::Variable(v));

let inner = substitute::substitute_bound(state, t, inner);
subsumes(state, context, t1, inner)
}

(Type::Forall(ref binder, inner), _) => {
let k = state.normalize_type(binder.kind);
let t = state.fresh_unification_kinded(k);

let inner = substitute::substitute_bound(state, t, inner);
subsumes(state, context, inner, t2)
}

_ => unify(state, context, t1, t2),
}
}

pub fn unify<Q>(state: &mut CheckState, context: &CheckContext<Q>, t1: TypeId, t2: TypeId) -> bool
where
Q: ExternalQueries,
Expand Down Expand Up @@ -39,6 +80,16 @@ where
solve(state, context, unification_id, t1).is_some()
}

(
&Type::Variable(Variable::Bound(t1_index)),
&Type::Variable(Variable::Bound(t2_index)),
) => t1_index == t2_index,

(
&Type::Variable(Variable::Skolem(t1_level, _)),
&Type::Variable(Variable::Skolem(t2_level, _)),
) => t1_level == t2_level,

_ => false,
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler-core/checking/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct ForallBinder {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Variable {
Implicit(debruijn::Level),
Skolem(debruijn::Level),
Skolem(debruijn::Level, TypeId),
Bound(debruijn::Index),
Free(SmolStr),
}
Expand Down
9 changes: 6 additions & 3 deletions compiler-core/checking/src/core/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ
}

let inner = traverse(source, inner);
write!(&mut buffer, " {inner}").unwrap();
write!(&mut buffer, ". {inner}").unwrap();

buffer
}
Expand All @@ -104,7 +104,7 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ

let result = traverse(source, result);
let arguments =
arguments.iter().rev().map(|argument| traverse(source, *argument)).join(" -> ");
arguments.iter().map(|argument| traverse(source, *argument)).join(" -> ");

format!("({arguments} -> {result})")
}
Expand Down Expand Up @@ -139,7 +139,10 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ

Type::Variable(ref variable) => match variable {
Variable::Implicit(level) => format!("{level}"),
Variable::Skolem(level) => format!("~{level}"),
Variable::Skolem(level, kind) => {
let kind = traverse(source, *kind);
format!("~{level} :: {kind}")
}
Variable::Bound(index) => format!("{index}"),
Variable::Free(name) => format!("{name}"),
},
Expand Down
Loading