From 9ee79745da7af2ff08df5c2fd1a85fc72c144ed1 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Mon, 3 Nov 2025 18:26:07 +0800 Subject: [PATCH 1/9] Begin implementing kind checker for data declaration --- compiler-core/checking/src/check/convert.rs | 44 ++++++++- .../checking/src/check/unification.rs | 20 +++++ compiler-core/checking/src/core/pretty.rs | 2 +- compiler-core/checking/src/lib.rs | 89 +++++++++++++++++-- tests-integration/tests/checking.rs | 8 +- tests-integration/tests/lowering_scc.rs | 29 ++++++ .../lowering_scc__non_cycle_ordering.snap | 18 ++++ 7 files changed, 195 insertions(+), 15 deletions(-) create mode 100644 tests-integration/tests/snapshots/lowering_scc__non_cycle_ordering.snap diff --git a/compiler-core/checking/src/check/convert.rs b/compiler-core/checking/src/check/convert.rs index e3ac1b3a..980abbeb 100644 --- a/compiler-core/checking/src/check/convert.rs +++ b/compiler-core/checking/src/check/convert.rs @@ -90,9 +90,51 @@ where } } +/// A variant of [`type_to_core`] for use with signature declarations. +/// +/// Unlike the regular [`type_to_core`], this function does not call +/// [`CheckState::unbind`] after each [`lowering::TypeKind::Forall`] +/// node. This allows type variables to be scoped for the entire +/// declaration group rather than just the type signature. +pub fn signature_type_to_core( + state: &mut CheckState, + context: &CheckContext, + id: lowering::TypeId, +) -> TypeId +where + Q: ExternalQueries, +{ + let default = context.prim.unknown; + + let Some(kind) = context.lowered.info.get_type_kind(id) else { + return default; + }; + + match kind { + lowering::TypeKind::Forall { bindings, type_ } => { + let binders = bindings + .iter() + .map(|binding| convert_forall_binding(state, context, binding)) + .collect_vec(); + + let inner = type_.map_or(default, |id| type_to_core(state, context, id)); + + binders + .into_iter() + .rfold(inner, |inner, binder| state.storage.intern(Type::Forall(binder, inner))) + } + + lowering::TypeKind::Parenthesized { parenthesized } => { + parenthesized.map(|id| signature_type_to_core(state, context, id)).unwrap_or(default) + } + + _ => type_to_core(state, context, id), + } +} + const INVALID_NAME: SmolStr = SmolStr::new_inline(""); -fn convert_forall_binding( +pub fn convert_forall_binding( state: &mut CheckState, context: &CheckContext, binding: &lowering::TypeVariableBinding, diff --git a/compiler-core/checking/src/check/unification.rs b/compiler-core/checking/src/check/unification.rs index 3cc05c23..a4eebdfa 100644 --- a/compiler-core/checking/src/check/unification.rs +++ b/compiler-core/checking/src/check/unification.rs @@ -7,6 +7,26 @@ use crate::ExternalQueries; use crate::check::{CheckContext, CheckState, kind}; use crate::core::{Type, TypeId, Variable, debruijn}; +pub fn subsumes( + state: &mut CheckState, + context: &CheckContext, + t1: TypeId, + t2: TypeId, +) -> bool +where + Q: ExternalQueries, +{ + let t1 = state.normalize_type(t1); + let t2 = state.normalize_type(t2); + match (&state.storage[t1], &state.storage[t2]) { + (&Type::Function(t1_argument, t1_result), &Type::Function(t2_argument, t2_result)) => { + subsumes(state, context, t2_argument, t1_argument) + && subsumes(state, context, t1_result, t2_result) + } + _ => unify(state, context, t1, t2), + } +} + pub fn unify(state: &mut CheckState, context: &CheckContext, t1: TypeId, t2: TypeId) -> bool where Q: ExternalQueries, diff --git a/compiler-core/checking/src/core/pretty.rs b/compiler-core/checking/src/core/pretty.rs index fd8da0d2..fa1ccce3 100644 --- a/compiler-core/checking/src/core/pretty.rs +++ b/compiler-core/checking/src/core/pretty.rs @@ -89,7 +89,7 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ } let inner = traverse(source, inner); - write!(&mut buffer, " {inner}").unwrap(); + write!(&mut buffer, ". {inner}").unwrap(); buffer } diff --git a/compiler-core/checking/src/lib.rs b/compiler-core/checking/src/lib.rs index ab241666..e900e37c 100644 --- a/compiler-core/checking/src/lib.rs +++ b/compiler-core/checking/src/lib.rs @@ -8,12 +8,14 @@ use std::sync::Arc; use building_types::{QueryProxy, QueryResult}; use files::FileId; use indexing::{IndexedModule, TermItemId, TypeItemId}; -use lowering::{LoweredModule, Scc}; +use itertools::Itertools; +use lowering::{DataIr, LoweredModule, Scc, TermItemIr, TypeItemIr, TypeVariableBinding}; use resolving::ResolvedModule; use rustc_hash::FxHashMap; -use crate::check::{CheckContext, CheckState, kind, transfer}; -use crate::core::{ForallBinder, Variable, debruijn}; +use crate::check::kind::check_surface_kind; +use crate::check::{CheckContext, CheckState, convert, kind, transfer}; +use crate::core::{ForallBinder, Variable, debruijn, pretty}; pub trait ExternalQueries: QueryProxy< @@ -85,15 +87,47 @@ where { let Some(item) = context.lowered.info.get_type_item(item_id) else { return }; match item { - lowering::TypeItemIr::DataGroup { .. } => (), + TypeItemIr::DataGroup { signature, data, .. } => { + let signature_type = signature + .map(|signature| convert::signature_type_to_core(state, context, signature)); + + if let Some(DataIr { variables }) = data { + let inferred_type = create_type_declaration_kind(state, context, &variables); + for constructor_id in context.indexed.pairs.data_constructors(item_id) { + let Some(TermItemIr::Constructor { arguments }) = + context.lowered.info.get_term_item(constructor_id) + else { + continue; + }; + for argument in arguments.iter() { + let (inferred_type, inferred_kind) = + check_surface_kind(state, context, *argument, context.prim.t); + { + let inferred_type = pretty::print_local(state, context, inferred_type); + let inferred_kind = pretty::print_local(state, context, inferred_kind); + eprintln!("{inferred_type} :: {inferred_kind}") + } + } + } + + { + if let Some(signature_type) = signature_type { + let signature_type = pretty::print_local(state, context, signature_type); + eprintln!("{signature_type}"); + } + let inferred_type = pretty::print_local(state, context, inferred_type); + eprintln!("{inferred_type}"); + } + } + } - lowering::TypeItemIr::NewtypeGroup { .. } => (), + TypeItemIr::NewtypeGroup { .. } => (), - lowering::TypeItemIr::SynonymGroup { .. } => (), + TypeItemIr::SynonymGroup { .. } => (), - lowering::TypeItemIr::ClassGroup { .. } => (), + TypeItemIr::ClassGroup { .. } => (), - lowering::TypeItemIr::Foreign { signature, .. } => { + TypeItemIr::Foreign { signature, .. } => { let Some(signature_id) = signature else { return }; let (inferred_type, _) = kind::check_surface_kind(state, context, *signature_id, context.prim.t); @@ -101,10 +135,47 @@ where state.checked.types.insert(item_id, inferred_type); } - lowering::TypeItemIr::Operator { .. } => (), + TypeItemIr::Operator { .. } => (), } } +fn create_type_declaration_kind( + state: &mut CheckState, + context: &CheckContext, + bindings: &[TypeVariableBinding], +) -> TypeId +where + Q: ExternalQueries, +{ + let binders = bindings + .iter() + .map(|binding| convert::convert_forall_binding(state, context, binding)) + .collect_vec(); + + // Build the function type for the type declaration e.g. + // + // ```purescript + // data Maybe a = Just a | Nothing + // ``` + // + // function_type := a -> Type + let size = state.bound.size(); + let function_type = binders.iter().rfold(context.prim.t, |result, binder| { + let index = binder.level.to_index(size).unwrap_or_else(|| { + unreachable!("invariant violated: invalid {} for {size}", binder.level) + }); + let variable = state.storage.intern(Type::Variable(Variable::Bound(index))); + state.storage.intern(Type::Function(variable, result)) + }); + + // Qualify the type variables in the function type e.g. + // + // forall (a :: Type). a -> Type + binders + .into_iter() + .rfold(function_type, |inner, binder| state.storage.intern(Type::Forall(binder, inner))) +} + fn prim_check_module( queries: &impl ExternalQueries, file_id: FileId, diff --git a/tests-integration/tests/checking.rs b/tests-integration/tests/checking.rs index 46044613..a62affb8 100644 --- a/tests-integration/tests/checking.rs +++ b/tests-integration/tests/checking.rs @@ -275,13 +275,13 @@ fn test_quantify_multiple_scoped() { fn test_manual() { let (engine, id) = empty_engine(); - engine.set_content(id, "module Main where\n\nforeign import data T :: Proxy 123"); + engine.set_content(id, "module Main where\n\ndata Either a b = Left a | Right b"); let resolved = engine.resolved(id).unwrap(); let checked = engine.checked(id).unwrap(); - let (_, id) = resolved.locals.lookup_type("T").unwrap(); - let id = checked.lookup_type(id).unwrap(); + // let (_, id) = resolved.locals.lookup_type("T").unwrap(); + // let id = checked.lookup_type(id).unwrap(); - eprintln!("{}", pretty::print_global(&engine, id)); + // eprintln!("{}", pretty::print_global(&engine, id)); } diff --git a/tests-integration/tests/lowering_scc.rs b/tests-integration/tests/lowering_scc.rs index 1aeeccff..9dbe8831 100644 --- a/tests-integration/tests/lowering_scc.rs +++ b/tests-integration/tests/lowering_scc.rs @@ -77,3 +77,32 @@ infix 5 type Add as + insta::assert_debug_snapshot!((terms, types)); } + +#[test] +fn test_non_cycle_ordering() {{ + let mut engine = QueryEngine::default(); + let mut files = Files::default(); + prim::configure(&mut engine, &mut files); + + let id = files.insert( + "Main.purs", + r#" +module Main where + +a _ = b 0 +b _ = c 0 +c _ = 0 +"#, + ); + let content = files.content(id); + + engine.set_content(id, content); + + let lowered = engine.lowered(id).unwrap(); + + let terms = &lowered.term_scc; + let types = &lowered.type_scc; + + insta::assert_debug_snapshot!((terms, types)); +} +} diff --git a/tests-integration/tests/snapshots/lowering_scc__non_cycle_ordering.snap b/tests-integration/tests/snapshots/lowering_scc__non_cycle_ordering.snap new file mode 100644 index 00000000..852dcea0 --- /dev/null +++ b/tests-integration/tests/snapshots/lowering_scc__non_cycle_ordering.snap @@ -0,0 +1,18 @@ +--- +source: tests-integration/tests/lowering_scc.rs +expression: "(terms, types)" +--- +( + [ + Base( + Idx::(2), + ), + Base( + Idx::(1), + ), + Base( + Idx::(0), + ), + ], + [], +) From 52ee6ae215e40cafb7527c278e1e76b0c05dfb9d Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Fri, 7 Nov 2025 00:48:31 +0800 Subject: [PATCH 2/9] Use symbol caching for workspace symbols --- Cargo.lock | 1 + compiler-lsp/analyzer/Cargo.toml | 3 + .../analyzer/src/completion/prelude.rs | 8 +- compiler-lsp/analyzer/src/definition.rs | 10 +- compiler-lsp/analyzer/src/locate.rs | 107 ++---- compiler-lsp/analyzer/src/locate/tests.rs | 315 ++++++++++++++++++ compiler-lsp/analyzer/src/symbols.rs | 145 ++++---- 7 files changed, 428 insertions(+), 161 deletions(-) create mode 100644 compiler-lsp/analyzer/src/locate/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 46e45815..b2b4a252 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,6 +40,7 @@ dependencies = [ "building", "files", "indexing", + "insta", "itertools 0.14.0", "la-arena", "lexing", diff --git a/compiler-lsp/analyzer/Cargo.toml b/compiler-lsp/analyzer/Cargo.toml index 0d5564a5..ab9426f6 100644 --- a/compiler-lsp/analyzer/Cargo.toml +++ b/compiler-lsp/analyzer/Cargo.toml @@ -27,3 +27,6 @@ syntax = { version = "0.1.0", path = "../../compiler-core/syntax" } thiserror = "2.0.16" tracing = "0.1.41" url = "2.5.7" + +[dev-dependencies] +insta = "1.41.1" diff --git a/compiler-lsp/analyzer/src/completion/prelude.rs b/compiler-lsp/analyzer/src/completion/prelude.rs index a1ed3f9c..334a2d57 100644 --- a/compiler-lsp/analyzer/src/completion/prelude.rs +++ b/compiler-lsp/analyzer/src/completion/prelude.rs @@ -41,7 +41,7 @@ impl Context<'_, '_> { |cst| Some(cst.syntax().text_range()), )?; - let mut position = locate::offset_to_position(self.content, range.end()); + let mut position = locate::offset_to_position(self.content, range.end())?; position.line += 1; position.character = 0; @@ -226,7 +226,7 @@ impl CursorText { (None, None) => None, }; - let range = range.map(|range| locate::text_range_to_range(content, range)); + let range = range.and_then(|range| locate::text_range_to_range(content, range)); let text = match (prefix, name) { (None, None) => CursorText::None, (Some(p), None) => CursorText::Prefix(p), @@ -247,7 +247,7 @@ impl CursorText { let prefix = SmolStr::new(prefix); let range = token.text_range(); - let range = locate::text_range_to_range(content, range); + let range = locate::text_range_to_range(content, range)?; let range = Some(range); let text = CursorText::Prefix(prefix); @@ -275,7 +275,7 @@ impl CursorText { (None, None) => None, }; - let range = range.map(|range| locate::text_range_to_range(content, range)); + let range = range.map(|range| locate::text_range_to_range(content, range))?; let text = match (prefix, name) { (None, None) => CursorText::None, (Some(p), None) => CursorText::Prefix(p), diff --git a/compiler-lsp/analyzer/src/definition.rs b/compiler-lsp/analyzer/src/definition.rs index ba02d6d3..e620c45c 100644 --- a/compiler-lsp/analyzer/src/definition.rs +++ b/compiler-lsp/analyzer/src/definition.rs @@ -83,12 +83,14 @@ fn definition_module_name( let root = parsed.syntax_node(); let range = root.text_range(); - let start = locate::offset_to_position(&content, range.start()); - let end = locate::offset_to_position(&content, range.end()); - let uri = common::file_uri(engine, files, module_id)?; + let range_start = + locate::offset_to_position(&content, range.start()).ok_or(AnalyzerError::NonFatal)?; + let range_end = + locate::offset_to_position(&content, range.end()).ok_or(AnalyzerError::NonFatal)?; - let range = Range { start, end }; + let uri = common::file_uri(engine, files, module_id)?; + let range = Range { start: range_start, end: range_end }; Ok(Some(GotoDefinitionResponse::Scalar(Location { uri, range }))) } diff --git a/compiler-lsp/analyzer/src/locate.rs b/compiler-lsp/analyzer/src/locate.rs index 9f1c9059..acada7e5 100644 --- a/compiler-lsp/analyzer/src/locate.rs +++ b/compiler-lsp/analyzer/src/locate.rs @@ -24,7 +24,7 @@ pub fn position_to_offset(content: &str, position: Position) -> Option let col = if line_content.is_empty() { 0 } else { - let last_column = || line_content.chars().count() as u32; + let last_column = || line_content.len() as u32; line_content .char_indices() .nth(position.character as usize) @@ -36,27 +36,44 @@ pub fn position_to_offset(content: &str, position: Position) -> Option line_index.offset(line_col) } -pub fn offset_to_position(content: &str, offset: TextSize) -> Position { +pub fn offset_to_position(content: &str, offset: TextSize) -> Option { let line_index = LineIndex::new(content); + let LineCol { line, col } = line_index.line_col(offset); - Position { line, character: col } + + let line_text_range = line_index.line(line)?; + let line_content = &content[line_text_range]; + + let until_col = &line_content[..col as usize]; + let character = until_col.chars().count() as u32; + + Some(Position { line, character }) } -pub fn text_range_to_range(content: &str, range: TextRange) -> Range { +pub fn text_range_to_range(content: &str, range: TextRange) -> Option { let line_index = LineIndex::new(content); - let start = line_index.line_col(range.start()); - let start = Position { line: start.line, character: start.col }; + let calculate = |offset: TextSize| { + let LineCol { line, col } = line_index.line_col(offset); + + let line_text_range = line_index.line(line)?; + let line_content = &content[line_text_range]; + + let until_col = &line_content[..col as usize]; + let character = until_col.chars().count() as u32; + + Some(Position { line, character }) + }; - let end = line_index.line_col(range.end()); - let end = Position { line: end.line, character: end.col }; + let start = calculate(range.start())?; + let end = calculate(range.end())?; - Range { start, end } + Some(Range { start, end }) } pub fn syntax_range(content: &str, root: &SyntaxNode, ptr: &SyntaxNodePtr) -> Option { let range = AnnotationSyntaxRange::from_ptr(root, ptr); - range.syntax.map(|range| text_range_to_range(content, range)) + range.syntax.and_then(|range| text_range_to_range(content, range)) } type ModuleNamePtr = AstPtr; @@ -186,72 +203,4 @@ fn locate_between( } #[cfg(test)] -mod tests { - use async_lsp::lsp_types::Position; - use rowan::TextSize; - - use super::position_to_offset; - - #[test] - fn zero_on_blank_line() { - let content = ""; - let position = Position::new(0, 0); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(0))); - } - - #[test] - fn zero_or_lf_line() { - let content = "\n"; - let position = Position::new(0, 0); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(0))); - } - - #[test] - fn zero_or_crlf_line() { - let content = "\r\n"; - let position = Position::new(0, 0); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(0))); - } - - #[test] - fn last_on_line() { - let content = "abcdef"; - let position = Position::new(0, 6); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(6))); - } - - #[test] - fn last_on_line_clamp() { - let content = "abcdef"; - let position = Position::new(0, 600); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(6))); - } - - #[test] - fn last_on_lf_line() { - let content = "abcdef\n"; - let position = Position::new(0, 6); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(6))); - } - - #[test] - fn last_on_crlf_line_clamp() { - let content = "abcdef\r\n"; - let position = Position::new(0, 600); - - let offset = position_to_offset(content, position); - assert_eq!(offset, Some(TextSize::new(6))); - } -} +mod tests; diff --git a/compiler-lsp/analyzer/src/locate/tests.rs b/compiler-lsp/analyzer/src/locate/tests.rs new file mode 100644 index 00000000..c036fa91 --- /dev/null +++ b/compiler-lsp/analyzer/src/locate/tests.rs @@ -0,0 +1,315 @@ +use async_lsp::lsp_types::Position; +use rowan::TextSize; + +use crate::locate::position_to_offset; + +#[test] +fn zero_on_blank_line() { + let content = ""; + let position = Position::new(0, 0); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(0))); +} + +#[test] +fn zero_or_lf_line() { + let content = "\n"; + let position = Position::new(0, 0); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(0))); +} + +#[test] +fn zero_or_crlf_line() { + let content = "\r\n"; + let position = Position::new(0, 0); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(0))); +} + +#[test] +fn last_on_line() { + let content = "abcdef"; + let position = Position::new(0, 6); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(6))); +} + +#[test] +fn last_on_line_clamp() { + let content = "abcdef"; + let position = Position::new(0, 600); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(6))); +} + +#[test] +fn last_on_lf_line() { + let content = "abcdef\n"; + let position = Position::new(0, 6); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(6))); +} + +#[test] +fn last_on_crlf_line_clamp() { + let content = "abcdef\r\n"; + let position = Position::new(0, 600); + + let offset = position_to_offset(content, position); + assert_eq!(offset, Some(TextSize::new(6))); +} + +mod text_range_to_range { + use rowan::{TextRange, TextSize}; + + use crate::locate::{position_to_offset, text_range_to_range}; + + /// Extracts a slice from `content` using `_` anchors to mark + /// the start and end. The content is returned with the anchors + /// removed. + /// + /// Example: `"hello _world_"` -> ("hello world", TextRange(6, 11)) + fn extract_range(content: &str) -> (String, TextRange) { + let mut clean = String::new(); + let mut positions = Vec::new(); + let mut byte_offset = 0u32; + + for ch in content.chars() { + if ch == '_' { + positions.push(TextSize::new(byte_offset)); + } else { + clean.push(ch); + byte_offset += ch.len_utf8() as u32; + } + } + + assert_eq!(positions.len(), 2, "Expected exactly 2 '_' anchors for range"); + (clean, TextRange::new(positions[0], positions[1])) + } + + fn format_test_case(input: &str) -> String { + let (content, range) = extract_range(input); + + let lsp_range = text_range_to_range(&content, range).unwrap(); + let text_range = { + let start = position_to_offset(&content, lsp_range.start).unwrap(); + let end = position_to_offset(&content, lsp_range.end).unwrap(); + TextRange::new(start, end) + }; + + format!( + "Content: {}\nRange: {}:{} -> {}:{}\nRoundtrip: {}", + &content[range], + lsp_range.start.line, + lsp_range.start.character, + lsp_range.end.line, + lsp_range.end.character, + range == text_range, + ) + } + + #[test] + fn simple() { + insta::assert_snapshot!(format_test_case("_hello, world_"), @r" + Content: hello, world + Range: 0:0 -> 0:12 + Roundtrip: true + "); + } + + #[test] + fn partial() { + insta::assert_snapshot!(format_test_case("hello, _world_"), @r" + Content: world + Range: 0:7 -> 0:12 + Roundtrip: true + "); + } + + #[test] + fn unicode_full() { + insta::assert_snapshot!(format_test_case("_content ∷ Type_"), @r" + Content: content ∷ Type + Range: 0:0 -> 0:14 + Roundtrip: true + "); + } + + #[test] + fn unicode_partial() { + insta::assert_snapshot!(format_test_case("content _∷ Type_"), @r" + Content: ∷ Type + Range: 0:8 -> 0:14 + Roundtrip: true + "); + } + + #[test] + fn unicode_ending() { + insta::assert_snapshot!(format_test_case("_content ∷_ Type"), @r" + Content: content ∷ + Range: 0:0 -> 0:9 + Roundtrip: true + "); + } + + #[test] + fn empty_range() { + insta::assert_snapshot!(format_test_case("hello__ world"), @r" + Content: + Range: 0:5 -> 0:5 + Roundtrip: true + "); + } + + #[test] + fn emoji_single() { + insta::assert_snapshot!(format_test_case("hello _😀_ world"), @r" + Content: 😀 + Range: 0:6 -> 0:7 + Roundtrip: true + "); + } + + #[test] + fn emoji_multiple() { + insta::assert_snapshot!(format_test_case("_🎉🎊🎈_"), @r" + Content: 🎉🎊🎈 + Range: 0:0 -> 0:3 + Roundtrip: true + "); + } + + #[test] + fn multiline_within_line() { + insta::assert_snapshot!(format_test_case("line1\nl_ine2_\nline3"), @r" + Content: ine2 + Range: 1:1 -> 1:5 + Roundtrip: true + "); + } + + #[test] + fn multiline_spanning() { + insta::assert_snapshot!(format_test_case("li_ne1\nline2_\nline3"), @r" + Content: ne1 + line2 + Range: 0:2 -> 1:5 + Roundtrip: true + "); + } + + #[test] + fn multiline_crlf() { + insta::assert_snapshot!(format_test_case("line1\r\n_line2_\r\nline3"), @r" + Content: line2 + Range: 1:0 -> 1:5 + Roundtrip: true + "); + } + + #[test] + fn multiline_spanning_crlf() { + insta::assert_snapshot!(format_test_case("fir_st\r\nsec_ond\r\nthird"), @r" + Content: st + sec + Range: 0:3 -> 1:3 + Roundtrip: true + "); + } + + #[test] + fn multiline_with_empty_lines() { + insta::assert_snapshot!(format_test_case("_first\n\nthird_"), @r" + Content: first + + third + Range: 0:0 -> 2:5 + Roundtrip: true + "); + } + + #[test] + fn multiline_unicode() { + insta::assert_snapshot!(format_test_case("type _∷ Type\nvalue ∷_ Int\ndata ∷ Data"), @r" + Content: ∷ Type + value ∷ + Range: 0:5 -> 1:7 + Roundtrip: true + "); + } + + #[test] + fn file_start() { + insta::assert_snapshot!(format_test_case("_hello_\nworld"), @r" + Content: hello + Range: 0:0 -> 0:5 + Roundtrip: true + "); + } + + #[test] + fn file_end() { + insta::assert_snapshot!(format_test_case("hello\n_world_"), @r" + Content: world + Range: 1:0 -> 1:5 + Roundtrip: true + "); + } + + #[test] + fn entire_file() { + insta::assert_snapshot!(format_test_case("_first\nsecond\nthird_"), @r" + Content: first + second + third + Range: 0:0 -> 2:5 + Roundtrip: true + "); + } + + #[test] + fn empty_file() { + insta::assert_snapshot!(format_test_case("__"), @r" + Content: + Range: 0:0 -> 0:0 + Roundtrip: true + "); + } + + #[test] + fn mixed_line_endings() { + insta::assert_snapshot!(format_test_case("first\n_second\r\nthird_\nfourth"), @r" + Content: second + third + Range: 1:0 -> 2:5 + Roundtrip: true + "); + } + + #[test] + fn line_boundary() { + insta::assert_snapshot!(format_test_case("lin_e1_\nline2"), @r" + Content: e1 + Range: 0:3 -> 0:5 + Roundtrip: true + "); + } + + #[test] + fn starting_at_newline() { + insta::assert_snapshot!(format_test_case("line1_\nline2_"), @r" + Content: + line2 + Range: 0:5 -> 1:5 + Roundtrip: true + "); + } +} diff --git a/compiler-lsp/analyzer/src/symbols.rs b/compiler-lsp/analyzer/src/symbols.rs index 4fa4410a..06caa3ac 100644 --- a/compiler-lsp/analyzer/src/symbols.rs +++ b/compiler-lsp/analyzer/src/symbols.rs @@ -1,6 +1,8 @@ +use std::sync::Arc; + use async_lsp::lsp_types::*; use building::QueryEngine; -use files::{FileId, Files}; +use files::Files; use radix_trie::Trie; use crate::{AnalyzerError, common}; @@ -15,90 +17,85 @@ pub fn workspace( return Ok(None); } - let mut output = vec![]; - - let candidates = - if let Some(cached_files) = cache.get(query).or_else(|| cache.get_ancestor_value(query)) { - collect_files(engine, files, query, &mut output, cached_files.iter().copied())? - } else { - collect_files(engine, files, query, &mut output, files.iter_id())? - }; - - let query = query.to_string(); - cache.insert(query, candidates); - - Ok(Some(WorkspaceSymbolResponse::Flat(output))) -} + let query = query.to_lowercase(); -fn collect_files( - engine: &QueryEngine, - files: &Files, - query: &str, - output: &mut Vec, - to_search: impl Iterator, -) -> Result, AnalyzerError> { - let mut candidates = vec![]; + if let Some(exact_symbols) = cache.get(&query) { + tracing::debug!("Found exact match for '{query}'"); + let flat = Vec::clone(&*exact_symbols); + return Ok(Some(WorkspaceSymbolResponse::Flat(flat))); + } - for file_id in to_search { - let previous_size = output.len(); - collect_terms_types(engine, files, query, output, file_id)?; - if output.len() > previous_size { - candidates.push(file_id); + let symbols = if let Some(prefix_symbols) = cache.get_ancestor_value(&query) { + tracing::debug!("Found prefix match for '{query}'"); + let filtered_symbols = filter_symbols(prefix_symbols, &query); + if filtered_symbols.len() == prefix_symbols.len() { + Arc::clone(prefix_symbols) + } else { + Arc::new(filtered_symbols) } - } + } else { + tracing::debug!("Initialising cache for '{query}'"); + let filtered_symbols = build_symbol_list(engine, files, &query)?; + Arc::new(filtered_symbols) + }; + + let key = String::clone(&query); + let value = Arc::clone(&symbols); + cache.insert(key, value); + + let flat = Vec::clone(&*symbols); + Ok(Some(WorkspaceSymbolResponse::Flat(flat))) +} - Ok(candidates) +fn filter_symbols(cached: &[SymbolInformation], query: &str) -> Vec { + cached.iter().filter(|symbol| symbol.name.to_lowercase().starts_with(query)).cloned().collect() } -fn collect_terms_types( +fn build_symbol_list( engine: &QueryEngine, files: &Files, query: &str, - output: &mut Vec, - file_id: FileId, -) -> Result<(), AnalyzerError> { - let resolved = engine.resolved(file_id)?; - let uri = common::file_uri(engine, files, file_id)?; - - let terms = resolved - .locals - .iter_terms() - .filter_map(|(name, _, id)| if name.starts_with(query) { Some((name, id)) } else { None }); - - let types = resolved - .locals - .iter_types() - .filter_map(|(name, _, id)| if name.starts_with(query) { Some((name, id)) } else { None }); - - for (name, term_id) in terms { - let uri = uri.clone(); - let location = common::file_term_location(engine, uri, file_id, term_id)?; - output.push(SymbolInformation { - name: name.to_string(), - kind: SymbolKind::FUNCTION, - tags: None, - #[allow(deprecated)] - deprecated: None, - location, - container_name: None, - }) - } +) -> Result, AnalyzerError> { + let mut symbols = vec![]; + + for file_id in files.iter_id() { + let resolved = engine.resolved(file_id)?; + let uri = common::file_uri(engine, files, file_id)?; + + for (name, _, term_id) in resolved.locals.iter_terms() { + if !name.to_lowercase().starts_with(query) { + continue; + } + let location = common::file_term_location(engine, uri.clone(), file_id, term_id)?; + symbols.push(SymbolInformation { + name: name.to_string(), + kind: SymbolKind::FUNCTION, + tags: None, + #[allow(deprecated)] + deprecated: None, + location, + container_name: None, + }); + } - for (name, type_id) in types { - let uri = uri.clone(); - let location = common::file_type_location(engine, uri, file_id, type_id)?; - output.push(SymbolInformation { - name: name.to_string(), - kind: SymbolKind::CLASS, - tags: None, - #[allow(deprecated)] - deprecated: None, - location, - container_name: None, - }) + for (name, _, type_id) in resolved.locals.iter_types() { + if !name.to_lowercase().starts_with(query) { + continue; + } + let location = common::file_type_location(engine, uri.clone(), file_id, type_id)?; + symbols.push(SymbolInformation { + name: name.to_string(), + kind: SymbolKind::CLASS, + tags: None, + #[allow(deprecated)] + deprecated: None, + location, + container_name: None, + }); + } } - Ok(()) + Ok(symbols) } -pub type WorkspaceSymbolsCache = Trie>; +pub type WorkspaceSymbolsCache = Trie>>; From de365a2958268a6a32e0b5cd48c2d1cdcb9a04bb Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Thu, 13 Nov 2025 23:26:48 +0100 Subject: [PATCH 3/9] Implement basic subsumption rules --- compiler-core/checking/src/check/kind.rs | 2 +- .../checking/src/check/unification.rs | 27 +++++- compiler-core/checking/src/core.rs | 2 +- compiler-core/checking/src/core/pretty.rs | 5 +- tests-integration/tests/checking.rs | 93 +++++++++++++++++-- tests-integration/tests/lowering_scc.rs | 3 +- .../checking__quantify_multiple_scoped.snap | 2 +- .../checking__quantify_ordering.snap | 2 +- .../checking__quantify_polykind.snap | 2 +- .../snapshots/checking__quantify_scoped.snap | 2 +- .../snapshots/checking__quantify_simple.snap | 2 +- 11 files changed, 121 insertions(+), 21 deletions(-) diff --git a/compiler-core/checking/src/check/kind.rs b/compiler-core/checking/src/check/kind.rs index 2d2ca0bb..ad86d9c1 100644 --- a/compiler-core/checking/src/check/kind.rs +++ b/compiler-core/checking/src/check/kind.rs @@ -223,7 +223,7 @@ where Type::Variable(ref variable) => match variable { Variable::Implicit(_) => context.prim.unknown, - Variable::Skolem(_) => context.prim.unknown, + Variable::Skolem(_, kind) => *kind, Variable::Bound(index) => { let size = state.bound.size(); diff --git a/compiler-core/checking/src/check/unification.rs b/compiler-core/checking/src/check/unification.rs index a4eebdfa..e8c923fe 100644 --- a/compiler-core/checking/src/check/unification.rs +++ b/compiler-core/checking/src/check/unification.rs @@ -4,7 +4,7 @@ pub mod level; pub use context::*; use crate::ExternalQueries; -use crate::check::{CheckContext, CheckState, kind}; +use crate::check::{CheckContext, CheckState, kind, substitute}; use crate::core::{Type, TypeId, Variable, debruijn}; pub fn subsumes( @@ -18,11 +18,32 @@ where { let t1 = state.normalize_type(t1); let t2 = state.normalize_type(t2); - match (&state.storage[t1], &state.storage[t2]) { - (&Type::Function(t1_argument, t1_result), &Type::Function(t2_argument, t2_result)) => { + + let t1_core = state.storage[t1].clone(); + let t2_core = state.storage[t2].clone(); + + match (t1_core, t2_core) { + (Type::Function(t1_argument, t1_result), Type::Function(t2_argument, t2_result)) => { subsumes(state, context, t2_argument, t1_argument) && subsumes(state, context, t1_result, t2_result) } + + (_, Type::Forall(ref binder, inner)) => { + let v = Variable::Skolem(binder.level, binder.kind); + let t = state.storage.intern(Type::Variable(v)); + + let inner = substitute::substitute_bound(state, t, inner); + subsumes(state, context, t1, inner) + } + + (Type::Forall(ref binder, inner), _) => { + let k = state.normalize_type(binder.kind); + let t = state.fresh_unification_kinded(k); + + let inner = substitute::substitute_bound(state, t, inner); + subsumes(state, context, inner, t2) + } + _ => unify(state, context, t1, t2), } } diff --git a/compiler-core/checking/src/core.rs b/compiler-core/checking/src/core.rs index 95f61f89..f3e45ebd 100644 --- a/compiler-core/checking/src/core.rs +++ b/compiler-core/checking/src/core.rs @@ -23,7 +23,7 @@ pub struct ForallBinder { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Variable { Implicit(debruijn::Level), - Skolem(debruijn::Level), + Skolem(debruijn::Level, TypeId), Bound(debruijn::Index), Free(SmolStr), } diff --git a/compiler-core/checking/src/core/pretty.rs b/compiler-core/checking/src/core/pretty.rs index fa1ccce3..4106954f 100644 --- a/compiler-core/checking/src/core/pretty.rs +++ b/compiler-core/checking/src/core/pretty.rs @@ -139,7 +139,10 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ Type::Variable(ref variable) => match variable { Variable::Implicit(level) => format!("{level}"), - Variable::Skolem(level) => format!("~{level}"), + Variable::Skolem(level, kind) => { + let kind = traverse(source, *kind); + format!("~{level} :: {kind}") + } Variable::Bound(index) => format!("{index}"), Variable::Free(name) => format!("{name}"), }, diff --git a/tests-integration/tests/checking.rs b/tests-integration/tests/checking.rs index a62affb8..a16e43cc 100644 --- a/tests-integration/tests/checking.rs +++ b/tests-integration/tests/checking.rs @@ -4,7 +4,7 @@ use std::num::NonZeroU32; use analyzer::{QueryEngine, prim}; use checking::check::unification::{self, UnificationState}; use checking::check::{CheckContext, CheckState, quantify}; -use checking::core::{Type, TypeId, Variable, debruijn, pretty}; +use checking::core::{ForallBinder, Type, TypeId, Variable, debruijn, pretty}; use files::{FileId, Files}; use lowering::TypeVariableBindingId; @@ -271,17 +271,94 @@ fn test_quantify_multiple_scoped() { insta::assert_snapshot!(snapshot) } +fn make_forall_a_to_a(context: &CheckContext, state: &mut CheckState) -> TypeId { + let fake_id = TypeVariableBindingId::new(FAKE_NONZERO_1); + + let level = state.bind_forall(fake_id, context.prim.t); + + let bound_a = state.bound_variable(0); + let a_to_a = state.function(bound_a, bound_a); + + let binder = ForallBinder { visible: false, name: "a".into(), level, kind: context.prim.t }; + let forall_a_to_a = state.storage.intern(Type::Forall(binder, a_to_a)); + + state.unbind(level); + + forall_a_to_a +} + #[test] -fn test_manual() { +fn test_subsumes_forall_left_pass() { let (engine, id) = empty_engine(); + let ContextState { ref context, ref mut state } = ContextState::new(&engine, id); + + // Given ∀a. (a -> a) + let forall_a_to_a = make_forall_a_to_a(context, state); + + // ∀a. (a -> a) should subsume (Int -> Int) + let int_to_int = state.function(context.prim.int, context.prim.int); + let result = unification::subsumes(state, context, forall_a_to_a, int_to_int); + assert!(result, "∀a. (a -> a) should subsume (Int -> Int)"); +} - engine.set_content(id, "module Main where\n\ndata Either a b = Left a | Right b"); +#[test] +fn test_subsumes_forall_left_fail() { + let (engine, id) = empty_engine(); + let ContextState { ref context, ref mut state } = ContextState::new(&engine, id); - let resolved = engine.resolved(id).unwrap(); - let checked = engine.checked(id).unwrap(); + // Given ∀a. (a -> a) + let forall_a_to_a = make_forall_a_to_a(context, state); - // let (_, id) = resolved.locals.lookup_type("T").unwrap(); - // let id = checked.lookup_type(id).unwrap(); + // ∀a. (a -> a) should NOT subsume (Int -> String) + let int_to_string = state.function(context.prim.int, context.prim.string); + let result = unification::subsumes(state, context, forall_a_to_a, int_to_string); + assert!(!result, "∀a. (a -> a) should not subsume (Int -> String)"); +} - // eprintln!("{}", pretty::print_global(&engine, id)); +#[test] +fn test_subsumes_forall_right_fail() { + let (engine, id) = empty_engine(); + let ContextState { ref context, ref mut state } = ContextState::new(&engine, id); + + // Create ∀a. a + let forall_a_to_a = make_forall_a_to_a(context, state); + + // Int should NOT subsume ∀a. a + let int_to_int = state.function(context.prim.int, context.prim.int); + let result = unification::subsumes(state, context, int_to_int, forall_a_to_a); + assert!(!result, "(Int -> Int) should not subsume ∀a. a -> a"); +} + +#[test] +fn test_subsumes_nested_forall() { + let (engine, id) = empty_engine(); + let ContextState { ref context, ref mut state } = ContextState::new(&engine, id); + + // Create ∀a. ∀b. (a -> b -> a) + let level_a = state.bind_forall(TypeVariableBindingId::new(FAKE_NONZERO_1), context.prim.t); + let level_b = state.bind_forall(TypeVariableBindingId::new(FAKE_NONZERO_2), context.prim.t); + + let bound_a = state.bound_variable(1); + let bound_b = state.bound_variable(0); + let b_to_a = state.function(bound_b, bound_a); + let a_to_b_to_a = state.function(bound_a, b_to_a); + + let forall_b = state.storage.intern(Type::Forall( + ForallBinder { visible: false, name: "b".into(), level: level_b, kind: context.prim.t }, + a_to_b_to_a, + )); + state.unbind(level_b); + + let forall_a_b = state.storage.intern(Type::Forall( + ForallBinder { visible: false, name: "a".into(), level: level_a, kind: context.prim.t }, + forall_b, + )); + state.unbind(level_a); + + // ∀a. ∀b. (a -> b -> a) should subsume (Int -> String -> Int) + let string_to_int = state.function(context.prim.string, context.prim.int); + let int_to_string_to_int = state.function(context.prim.int, string_to_int); + + let result = unification::subsumes(state, context, forall_a_b, int_to_string_to_int); + assert!(result, "∀a. ∀b. (a -> b -> a) should subsume (Int -> String -> Int)"); } diff --git a/tests-integration/tests/lowering_scc.rs b/tests-integration/tests/lowering_scc.rs index 9dbe8831..840eb50d 100644 --- a/tests-integration/tests/lowering_scc.rs +++ b/tests-integration/tests/lowering_scc.rs @@ -79,7 +79,7 @@ infix 5 type Add as + } #[test] -fn test_non_cycle_ordering() {{ +fn test_non_cycle_ordering() { let mut engine = QueryEngine::default(); let mut files = Files::default(); prim::configure(&mut engine, &mut files); @@ -105,4 +105,3 @@ c _ = 0 insta::assert_debug_snapshot!((terms, types)); } -} diff --git a/tests-integration/tests/snapshots/checking__quantify_multiple_scoped.snap b/tests-integration/tests/snapshots/checking__quantify_multiple_scoped.snap index 90ec330f..16be784b 100644 --- a/tests-integration/tests/snapshots/checking__quantify_multiple_scoped.snap +++ b/tests-integration/tests/snapshots/checking__quantify_multiple_scoped.snap @@ -2,4 +2,4 @@ source: tests-integration/tests/checking.rs expression: snapshot --- -forall (t0 :: Type) (t1 :: *0) (t2 :: *0) (t3 :: Type) (t4 :: *0) (t5 :: *0) (*3 -> *0) +forall (t0 :: Type) (t1 :: *0) (t2 :: *0) (t3 :: Type) (t4 :: *0) (t5 :: *0). (*3 -> *0) diff --git a/tests-integration/tests/snapshots/checking__quantify_ordering.snap b/tests-integration/tests/snapshots/checking__quantify_ordering.snap index 90e99a57..fd548d89 100644 --- a/tests-integration/tests/snapshots/checking__quantify_ordering.snap +++ b/tests-integration/tests/snapshots/checking__quantify_ordering.snap @@ -2,4 +2,4 @@ source: tests-integration/tests/checking.rs expression: snapshot --- -forall (t0 :: Type) (t1 :: Type) (*0 -> *1) +forall (t0 :: Type) (t1 :: Type). (*0 -> *1) diff --git a/tests-integration/tests/snapshots/checking__quantify_polykind.snap b/tests-integration/tests/snapshots/checking__quantify_polykind.snap index 7479ffff..cddfa5df 100644 --- a/tests-integration/tests/snapshots/checking__quantify_polykind.snap +++ b/tests-integration/tests/snapshots/checking__quantify_polykind.snap @@ -2,4 +2,4 @@ source: tests-integration/tests/checking.rs expression: snapshot --- -forall (t0 :: Type) (t1 :: *0) *0 +forall (t0 :: Type) (t1 :: *0). *0 diff --git a/tests-integration/tests/snapshots/checking__quantify_scoped.snap b/tests-integration/tests/snapshots/checking__quantify_scoped.snap index 19165759..5ca7bbcb 100644 --- a/tests-integration/tests/snapshots/checking__quantify_scoped.snap +++ b/tests-integration/tests/snapshots/checking__quantify_scoped.snap @@ -2,4 +2,4 @@ source: tests-integration/tests/checking.rs expression: snapshot --- -forall (t0 :: Type) (t1 :: *0) (t2 :: *0) *0 +forall (t0 :: Type) (t1 :: *0) (t2 :: *0). *0 diff --git a/tests-integration/tests/snapshots/checking__quantify_simple.snap b/tests-integration/tests/snapshots/checking__quantify_simple.snap index 07ef2047..5acf8c31 100644 --- a/tests-integration/tests/snapshots/checking__quantify_simple.snap +++ b/tests-integration/tests/snapshots/checking__quantify_simple.snap @@ -2,4 +2,4 @@ source: tests-integration/tests/checking.rs expression: snapshot --- -forall (t0 :: Type) (t1 :: Type) (*1 -> *0) +forall (t0 :: Type) (t1 :: Type). (*1 -> *0) From e82716babefc26dd8aff7a7da1446997cd49979f Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Thu, 13 Nov 2025 23:48:36 +0100 Subject: [PATCH 4/9] Apply clippy fixes --- compiler-core/checking/src/lib.rs | 2 +- compiler-lsp/analyzer/src/common.rs | 2 +- compiler-lsp/analyzer/src/symbols.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler-core/checking/src/lib.rs b/compiler-core/checking/src/lib.rs index e900e37c..f27e7467 100644 --- a/compiler-core/checking/src/lib.rs +++ b/compiler-core/checking/src/lib.rs @@ -92,7 +92,7 @@ where .map(|signature| convert::signature_type_to_core(state, context, signature)); if let Some(DataIr { variables }) = data { - let inferred_type = create_type_declaration_kind(state, context, &variables); + let inferred_type = create_type_declaration_kind(state, context, variables); for constructor_id in context.indexed.pairs.data_constructors(item_id) { let Some(TermItemIr::Constructor { arguments }) = context.lowered.info.get_term_item(constructor_id) diff --git a/compiler-lsp/analyzer/src/common.rs b/compiler-lsp/analyzer/src/common.rs index 9079f095..1292c20e 100644 --- a/compiler-lsp/analyzer/src/common.rs +++ b/compiler-lsp/analyzer/src/common.rs @@ -63,7 +63,7 @@ fn pointers_range( pointers: impl Iterator, ) -> Result { pointers - .filter_map(|ptr| locate::syntax_range(&content, &root, &ptr)) + .filter_map(|ptr| locate::syntax_range(content, &root, &ptr)) .reduce(|start, end| Range { start: start.start, end: end.end }) .ok_or(AnalyzerError::NonFatal) } diff --git a/compiler-lsp/analyzer/src/symbols.rs b/compiler-lsp/analyzer/src/symbols.rs index 06caa3ac..d90edebf 100644 --- a/compiler-lsp/analyzer/src/symbols.rs +++ b/compiler-lsp/analyzer/src/symbols.rs @@ -21,7 +21,7 @@ pub fn workspace( if let Some(exact_symbols) = cache.get(&query) { tracing::debug!("Found exact match for '{query}'"); - let flat = Vec::clone(&*exact_symbols); + let flat = Vec::clone(exact_symbols); return Ok(Some(WorkspaceSymbolResponse::Flat(flat))); } From 0318990d923f3270f77301c1be50e0c0972a2e13 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Fri, 14 Nov 2025 09:40:28 +0800 Subject: [PATCH 5/9] Small copy changes --- tests-integration/tests/checking.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests-integration/tests/checking.rs b/tests-integration/tests/checking.rs index a16e43cc..ec8aad0f 100644 --- a/tests-integration/tests/checking.rs +++ b/tests-integration/tests/checking.rs @@ -320,13 +320,13 @@ fn test_subsumes_forall_right_fail() { let (engine, id) = empty_engine(); let ContextState { ref context, ref mut state } = ContextState::new(&engine, id); - // Create ∀a. a + // Create ∀a. a -> a let forall_a_to_a = make_forall_a_to_a(context, state); - // Int should NOT subsume ∀a. a + // (Int -> Int) should NOT subsume ∀a. (a -> a) let int_to_int = state.function(context.prim.int, context.prim.int); let result = unification::subsumes(state, context, int_to_int, forall_a_to_a); - assert!(!result, "(Int -> Int) should not subsume ∀a. a -> a"); + assert!(!result, "(Int -> Int) should not subsume ∀a. (a -> a)"); } #[test] From 1dba9ccb51f213affb0ec20a76fdaac6da4b19e8 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Mon, 17 Nov 2025 05:22:55 +0800 Subject: [PATCH 6/9] Update snapshots to reflect API change --- ...ing__001_ado_statement_recursion_main.snap | 8 ++-- .../lowering__002_class_equation_main.snap | 44 +++++++++--------- .../lowering__003_data_equation_main.snap | 20 ++++---- ...lowering__004_derive_declaration_main.snap | 4 +- .../lowering__005_do_statement_main.snap | 24 +++++----- ...ring__006_do_statement_recursion_main.snap | 18 ++++---- ...wering__007_instance_declaration_main.snap | 46 +++++++++---------- .../lowering__008_newtype_equation_main.snap | 16 +++---- ...lowering__009_signature_equation_main.snap | 12 ++--- .../lowering__010_value_equation_main.snap | 44 +++++++++--------- 10 files changed, 118 insertions(+), 118 deletions(-) diff --git a/tests-integration/tests/snapshots/lowering__001_ado_statement_recursion_main.snap b/tests-integration/tests/snapshots/lowering__001_ado_statement_recursion_main.snap index 72cf186a..299de31c 100644 --- a/tests-integration/tests/snapshots/lowering__001_ado_statement_recursion_main.snap +++ b/tests-integration/tests/snapshots/lowering__001_ado_statement_recursion_main.snap @@ -6,11 +6,11 @@ module Main Expressions: -x'notBinder@Position { line: 4, character: 4 } - resolves to binder Position { line: 2, character: 10 } -x'notBinder@Position { line: 3, character: 21 } +x'notBinder@Some(Position { line: 4, character: 4 }) + resolves to binder Some(Position { line: 2, character: 10 }) +x'notBinder@Some(Position { line: 3, character: 21 }) resolves to top-level name -pure@Position { line: 3, character: 16 } +pure@Some(Position { line: 3, character: 16 }) resolves to top-level name Types: diff --git a/tests-integration/tests/snapshots/lowering__002_class_equation_main.snap b/tests-integration/tests/snapshots/lowering__002_class_equation_main.snap index 6504ae20..8ed1eb37 100644 --- a/tests-integration/tests/snapshots/lowering__002_class_equation_main.snap +++ b/tests-integration/tests/snapshots/lowering__002_class_equation_main.snap @@ -9,25 +9,25 @@ Expressions: Types: -k@Position { line: 2, character: 30 } - resolves to forall Position { line: 2, character: 22 } -p@Position { line: 4, character: 20 } - resolves to forall Position { line: 4, character: 17 } -p@Position { line: 4, character: 27 } - resolves to forall Position { line: 4, character: 17 } -k@Position { line: 2, character: 25 } - resolves to forall Position { line: 2, character: 22 } -a@Position { line: 7, character: 17 } - resolves to forall Position { line: 6, character: 17 } -a@Position { line: 6, character: 8 } - resolves to forall Position { line: 6, character: 17 } -a@Position { line: 4, character: 22 } - resolves to forall Position { line: 3, character: 12 } -k@Position { line: 3, character: 27 } - resolves to forall Position { line: 2, character: 22 } -b@Position { line: 4, character: 29 } - resolves to forall Position { line: 3, character: 21 } -a@Position { line: 7, character: 12 } - resolves to forall Position { line: 6, character: 17 } -k@Position { line: 3, character: 18 } - resolves to forall Position { line: 2, character: 22 } +k@Some(Position { line: 2, character: 30 }) + resolves to forall Some(Position { line: 2, character: 22 }) +p@Some(Position { line: 4, character: 20 }) + resolves to forall Some(Position { line: 4, character: 17 }) +p@Some(Position { line: 4, character: 27 }) + resolves to forall Some(Position { line: 4, character: 17 }) +k@Some(Position { line: 2, character: 25 }) + resolves to forall Some(Position { line: 2, character: 22 }) +a@Some(Position { line: 7, character: 17 }) + resolves to forall Some(Position { line: 6, character: 17 }) +a@Some(Position { line: 6, character: 8 }) + resolves to forall Some(Position { line: 6, character: 17 }) +a@Some(Position { line: 4, character: 22 }) + resolves to forall Some(Position { line: 3, character: 12 }) +k@Some(Position { line: 3, character: 27 }) + resolves to forall Some(Position { line: 2, character: 22 }) +b@Some(Position { line: 4, character: 29 }) + resolves to forall Some(Position { line: 3, character: 21 }) +a@Some(Position { line: 7, character: 12 }) + resolves to forall Some(Position { line: 6, character: 17 }) +k@Some(Position { line: 3, character: 18 }) + resolves to forall Some(Position { line: 2, character: 22 }) diff --git a/tests-integration/tests/snapshots/lowering__003_data_equation_main.snap b/tests-integration/tests/snapshots/lowering__003_data_equation_main.snap index e0b9a22c..3b308eec 100644 --- a/tests-integration/tests/snapshots/lowering__003_data_equation_main.snap +++ b/tests-integration/tests/snapshots/lowering__003_data_equation_main.snap @@ -9,13 +9,13 @@ Expressions: Types: -a@Position { line: 4, character: 22 } - resolves to forall Position { line: 4, character: 11 } -k@Position { line: 6, character: 23 } - resolves to forall Position { line: 6, character: 20 } -k@Position { line: 7, character: 16 } - resolves to forall Position { line: 6, character: 20 } -b@Position { line: 4, character: 32 } - resolves to forall Position { line: 4, character: 13 } -a@Position { line: 2, character: 19 } - resolves to forall Position { line: 2, character: 10 } +a@Some(Position { line: 4, character: 22 }) + resolves to forall Some(Position { line: 4, character: 11 }) +k@Some(Position { line: 6, character: 23 }) + resolves to forall Some(Position { line: 6, character: 20 }) +k@Some(Position { line: 7, character: 16 }) + resolves to forall Some(Position { line: 6, character: 20 }) +b@Some(Position { line: 4, character: 32 }) + resolves to forall Some(Position { line: 4, character: 13 }) +a@Some(Position { line: 2, character: 19 }) + resolves to forall Some(Position { line: 2, character: 10 }) diff --git a/tests-integration/tests/snapshots/lowering__004_derive_declaration_main.snap b/tests-integration/tests/snapshots/lowering__004_derive_declaration_main.snap index 0f012d6e..cf2de88c 100644 --- a/tests-integration/tests/snapshots/lowering__004_derive_declaration_main.snap +++ b/tests-integration/tests/snapshots/lowering__004_derive_declaration_main.snap @@ -9,7 +9,7 @@ Expressions: Types: -a@Position { line: 2, character: 25 } +a@Some(Position { line: 2, character: 25 }) introduces a constraint variable "a" -a@Position { line: 3, character: 34 } +a@Some(Position { line: 3, character: 34 }) introduces a constraint variable "a" diff --git a/tests-integration/tests/snapshots/lowering__005_do_statement_main.snap b/tests-integration/tests/snapshots/lowering__005_do_statement_main.snap index 012f8b60..2f364071 100644 --- a/tests-integration/tests/snapshots/lowering__005_do_statement_main.snap +++ b/tests-integration/tests/snapshots/lowering__005_do_statement_main.snap @@ -6,23 +6,23 @@ module Main Expressions: -x@Position { line: 3, character: 8 } +x@Some(Position { line: 3, character: 8 }) resolves to top-level name -pure@Position { line: 5, character: 6 } +pure@Some(Position { line: 5, character: 6 }) resolves to top-level name -pure@Position { line: 6, character: 12 } +pure@Some(Position { line: 6, character: 12 }) resolves to top-level name -x@Position { line: 7, character: 8 } - resolves to equation Position { line: 4, character: 5 } -action@Position { line: 2, character: 9 } +x@Some(Position { line: 7, character: 8 }) + resolves to equation Some(Position { line: 4, character: 5 }) +action@Some(Position { line: 2, character: 9 }) resolves to top-level name -y@Position { line: 7, character: 12 } - resolves to binder Position { line: 4, character: 12 } -z@Position { line: 7, character: 16 } - resolves to equation Position { line: 6, character: 5 } -z@Position { line: 3, character: 12 } +y@Some(Position { line: 7, character: 12 }) + resolves to binder Some(Position { line: 4, character: 12 }) +z@Some(Position { line: 7, character: 16 }) + resolves to equation Some(Position { line: 6, character: 5 }) +z@Some(Position { line: 3, character: 12 }) resolves to top-level name -y@Position { line: 3, character: 10 } +y@Some(Position { line: 3, character: 10 }) resolves to top-level name Types: diff --git a/tests-integration/tests/snapshots/lowering__006_do_statement_recursion_main.snap b/tests-integration/tests/snapshots/lowering__006_do_statement_recursion_main.snap index b6ddca19..55a76d36 100644 --- a/tests-integration/tests/snapshots/lowering__006_do_statement_recursion_main.snap +++ b/tests-integration/tests/snapshots/lowering__006_do_statement_recursion_main.snap @@ -6,17 +6,17 @@ module Main Expressions: -y'equation@Position { line: 3, character: 27 } - resolves to equation Position { line: 3, character: 5 } -pure@Position { line: 4, character: 33 } +y'equation@Some(Position { line: 3, character: 27 }) + resolves to equation Some(Position { line: 3, character: 5 }) +pure@Some(Position { line: 4, character: 33 }) resolves to top-level name -pure@Position { line: 4, character: 16 } +pure@Some(Position { line: 4, character: 16 }) resolves to top-level name -a'binder@Position { line: 3, character: 38 } - resolves to binder Position { line: 3, character: 16 } -x'notBinder@Position { line: 5, character: 6 } - resolves to binder Position { line: 3, character: 47 } -x'notBinder@Position { line: 4, character: 21 } +a'binder@Some(Position { line: 3, character: 38 }) + resolves to binder Some(Position { line: 3, character: 16 }) +x'notBinder@Some(Position { line: 5, character: 6 }) + resolves to binder Some(Position { line: 3, character: 47 }) +x'notBinder@Some(Position { line: 4, character: 21 }) resolves to top-level name Types: diff --git a/tests-integration/tests/snapshots/lowering__007_instance_declaration_main.snap b/tests-integration/tests/snapshots/lowering__007_instance_declaration_main.snap index 70564534..1af9b597 100644 --- a/tests-integration/tests/snapshots/lowering__007_instance_declaration_main.snap +++ b/tests-integration/tests/snapshots/lowering__007_instance_declaration_main.snap @@ -6,39 +6,39 @@ module Main Expressions: -eqMaybeImpl@Position { line: 8, character: 6 } +eqMaybeImpl@Some(Position { line: 8, character: 6 }) resolves to top-level name -eqIntImpl@Position { line: 4, character: 6 } +eqIntImpl@Some(Position { line: 4, character: 6 }) resolves to top-level name -b@Position { line: 12, character: 11 } - resolves to binder Position { line: 12, character: 7 } +b@Some(Position { line: 12, character: 11 }) + resolves to binder Some(Position { line: 12, character: 7 }) Types: -a@Position { line: 6, character: 11 } +a@Some(Position { line: 6, character: 11 }) resolves to a constraint variable "a" - Position { line: 6, character: 26 } -b@Position { line: 10, character: 21 } + Some(Position { line: 6, character: 26 }) +b@Some(Position { line: 10, character: 21 }) introduces a constraint variable "b" -a@Position { line: 7, character: 24 } +a@Some(Position { line: 7, character: 24 }) resolves to a constraint variable "a" - Position { line: 6, character: 26 } -b@Position { line: 10, character: 19 } + Some(Position { line: 6, character: 26 }) +b@Some(Position { line: 10, character: 19 }) introduces a constraint variable "b" -a@Position { line: 6, character: 26 } +a@Some(Position { line: 6, character: 26 }) introduces a constraint variable "a" -a@Position { line: 7, character: 13 } +a@Some(Position { line: 7, character: 13 }) resolves to a constraint variable "a" - Position { line: 6, character: 26 } -b@Position { line: 11, character: 22 } + Some(Position { line: 6, character: 26 }) +b@Some(Position { line: 11, character: 22 }) resolves to a constraint variable "b" - Position { line: 10, character: 19 } - Position { line: 10, character: 21 } -b@Position { line: 11, character: 29 } + Some(Position { line: 10, character: 19 }) + Some(Position { line: 10, character: 21 }) +b@Some(Position { line: 11, character: 29 }) resolves to a constraint variable "b" - Position { line: 10, character: 19 } - Position { line: 10, character: 21 } -p@Position { line: 11, character: 20 } - resolves to forall Position { line: 11, character: 17 } -p@Position { line: 11, character: 27 } - resolves to forall Position { line: 11, character: 17 } + Some(Position { line: 10, character: 19 }) + Some(Position { line: 10, character: 21 }) +p@Some(Position { line: 11, character: 20 }) + resolves to forall Some(Position { line: 11, character: 17 }) +p@Some(Position { line: 11, character: 27 }) + resolves to forall Some(Position { line: 11, character: 17 }) diff --git a/tests-integration/tests/snapshots/lowering__008_newtype_equation_main.snap b/tests-integration/tests/snapshots/lowering__008_newtype_equation_main.snap index cd9f5d62..94b3f82d 100644 --- a/tests-integration/tests/snapshots/lowering__008_newtype_equation_main.snap +++ b/tests-integration/tests/snapshots/lowering__008_newtype_equation_main.snap @@ -9,11 +9,11 @@ Expressions: Types: -a@Position { line: 5, character: 29 } - resolves to forall Position { line: 5, character: 9 } -k@Position { line: 5, character: 15 } - resolves to forall Position { line: 4, character: 19 } -k@Position { line: 4, character: 22 } - resolves to forall Position { line: 4, character: 19 } -a@Position { line: 2, character: 17 } - resolves to forall Position { line: 2, character: 10 } +a@Some(Position { line: 5, character: 29 }) + resolves to forall Some(Position { line: 5, character: 9 }) +k@Some(Position { line: 5, character: 15 }) + resolves to forall Some(Position { line: 4, character: 19 }) +k@Some(Position { line: 4, character: 22 }) + resolves to forall Some(Position { line: 4, character: 19 }) +a@Some(Position { line: 2, character: 17 }) + resolves to forall Some(Position { line: 2, character: 10 }) diff --git a/tests-integration/tests/snapshots/lowering__009_signature_equation_main.snap b/tests-integration/tests/snapshots/lowering__009_signature_equation_main.snap index 4d0029b0..3b7cdb79 100644 --- a/tests-integration/tests/snapshots/lowering__009_signature_equation_main.snap +++ b/tests-integration/tests/snapshots/lowering__009_signature_equation_main.snap @@ -9,9 +9,9 @@ Expressions: Types: -k@Position { line: 3, character: 12 } - resolves to forall Position { line: 2, character: 16 } -k@Position { line: 2, character: 19 } - resolves to forall Position { line: 2, character: 16 } -a@Position { line: 3, character: 23 } - resolves to forall Position { line: 3, character: 6 } +k@Some(Position { line: 3, character: 12 }) + resolves to forall Some(Position { line: 2, character: 16 }) +k@Some(Position { line: 2, character: 19 }) + resolves to forall Some(Position { line: 2, character: 16 }) +a@Some(Position { line: 3, character: 23 }) + resolves to forall Some(Position { line: 3, character: 6 }) diff --git a/tests-integration/tests/snapshots/lowering__010_value_equation_main.snap b/tests-integration/tests/snapshots/lowering__010_value_equation_main.snap index 4ef4c8cf..615c35fe 100644 --- a/tests-integration/tests/snapshots/lowering__010_value_equation_main.snap +++ b/tests-integration/tests/snapshots/lowering__010_value_equation_main.snap @@ -6,29 +6,29 @@ module Main Expressions: -a@Position { line: 3, character: 15 } - resolves to binder Position { line: 3, character: 4 } -z@Position { line: 10, character: 4 } - resolves to signature Position { line: 7, character: 5 } - resolves to equation Position { line: 8, character: 12 } -add@Position { line: 14, character: 6 } +a@Some(Position { line: 3, character: 15 }) + resolves to binder Some(Position { line: 3, character: 4 }) +z@Some(Position { line: 10, character: 4 }) + resolves to signature Some(Position { line: 7, character: 5 }) + resolves to equation Some(Position { line: 8, character: 12 }) +add@Some(Position { line: 14, character: 6 }) resolves to top-level name -a@Position { line: 9, character: 7 } - resolves to binder Position { line: 6, character: 5 } +a@Some(Position { line: 9, character: 7 }) + resolves to binder Some(Position { line: 6, character: 5 }) Types: -a@Position { line: 3, character: 19 } - resolves to forall Position { line: 2, character: 12 } -a@Position { line: 2, character: 20 } - resolves to forall Position { line: 2, character: 12 } -a@Position { line: 5, character: 20 } - resolves to forall Position { line: 5, character: 15 } -a@Position { line: 2, character: 15 } - resolves to forall Position { line: 2, character: 12 } -a@Position { line: 5, character: 30 } - resolves to forall Position { line: 5, character: 15 } -a@Position { line: 3, character: 8 } - resolves to forall Position { line: 2, character: 12 } -b@Position { line: 5, character: 25 } - resolves to forall Position { line: 5, character: 17 } +a@Some(Position { line: 3, character: 19 }) + resolves to forall Some(Position { line: 2, character: 12 }) +a@Some(Position { line: 2, character: 20 }) + resolves to forall Some(Position { line: 2, character: 12 }) +a@Some(Position { line: 5, character: 20 }) + resolves to forall Some(Position { line: 5, character: 15 }) +a@Some(Position { line: 2, character: 15 }) + resolves to forall Some(Position { line: 2, character: 12 }) +a@Some(Position { line: 5, character: 30 }) + resolves to forall Some(Position { line: 5, character: 15 }) +a@Some(Position { line: 3, character: 8 }) + resolves to forall Some(Position { line: 2, character: 12 }) +b@Some(Position { line: 5, character: 25 }) + resolves to forall Some(Position { line: 5, character: 17 }) From 79fad0c19ce41efdb637e61a249899c69b1eaba5 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Mon, 17 Nov 2025 05:23:19 +0800 Subject: [PATCH 7/9] Fix pretty printing traversal for function types --- compiler-core/checking/src/core/pretty.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler-core/checking/src/core/pretty.rs b/compiler-core/checking/src/core/pretty.rs index 4106954f..01b24545 100644 --- a/compiler-core/checking/src/core/pretty.rs +++ b/compiler-core/checking/src/core/pretty.rs @@ -104,7 +104,7 @@ fn traverse<'a, Q: ExternalQueries>(source: &mut TraversalSource<'a, Q>, id: Typ let result = traverse(source, result); let arguments = - arguments.iter().rev().map(|argument| traverse(source, *argument)).join(" -> "); + arguments.iter().map(|argument| traverse(source, *argument)).join(" -> "); format!("({arguments} -> {result})") } From 6dc5a87718401e522c69f0e7fad34100ce2b92e0 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Tue, 18 Nov 2025 00:25:49 +0800 Subject: [PATCH 8/9] Implement initial type checking for data declarations --- compiler-core/checking/src/check/convert.rs | 89 +++++++- compiler-core/checking/src/check/kind.rs | 2 +- compiler-core/checking/src/check/state.rs | 5 +- .../checking/src/check/unification.rs | 10 + compiler-core/checking/src/lib.rs | 203 ++++++++++++------ .../lowering/src/algorithm/recursive.rs | 8 +- compiler-core/lowering/src/intermediate.rs | 2 +- tests-integration/tests/checking.rs | 93 ++++++++ 8 files changed, 328 insertions(+), 84 deletions(-) diff --git a/compiler-core/checking/src/check/convert.rs b/compiler-core/checking/src/check/convert.rs index 980abbeb..f93866ac 100644 --- a/compiler-core/checking/src/check/convert.rs +++ b/compiler-core/checking/src/check/convert.rs @@ -1,3 +1,5 @@ +use std::iter; + use itertools::Itertools; use smol_str::SmolStr; @@ -39,13 +41,13 @@ where }; state.storage.intern(Type::Constructor(file_id, type_id)) } - lowering::TypeKind::Forall { bindings, type_ } => { + lowering::TypeKind::Forall { bindings, inner } => { let binders = bindings .iter() .map(|binding| convert_forall_binding(state, context, binding)) .collect_vec(); - let inner = type_.map_or(default, |id| type_to_core(state, context, id)); + let inner = inner.map_or(default, |id| type_to_core(state, context, id)); let forall = binders .into_iter() @@ -65,7 +67,7 @@ where lowering::TypeKind::String => default, lowering::TypeKind::Variable { name, resolution } => { let Some(resolution) = resolution else { - let name = name.clone().unwrap_or(INVALID_NAME); + let name = name.clone().unwrap_or(MISSING_NAME); let kind = Variable::Free(name); return state.storage.intern(Type::Variable(kind)); }; @@ -111,13 +113,13 @@ where }; match kind { - lowering::TypeKind::Forall { bindings, type_ } => { + lowering::TypeKind::Forall { bindings, inner } => { let binders = bindings .iter() .map(|binding| convert_forall_binding(state, context, binding)) .collect_vec(); - let inner = type_.map_or(default, |id| type_to_core(state, context, id)); + let inner = inner.map_or(default, |id| type_to_core(state, context, id)); binders .into_iter() @@ -132,7 +134,78 @@ where } } -const INVALID_NAME: SmolStr = SmolStr::new_inline(""); +pub struct InspectSignature { + pub variables: Vec, + pub arguments: Vec, + pub result: TypeId, +} + +pub fn inspect_signature( + state: &mut CheckState, + context: &CheckContext, + id: lowering::TypeId, +) -> InspectSignature +where + Q: ExternalQueries, +{ + let unknown = || { + let variables = [].into(); + let arguments = [].into(); + let result = context.prim.unknown; + InspectSignature { variables, arguments, result } + }; + + let Some(kind) = context.lowered.info.get_type_kind(id) else { + return unknown(); + }; + + match kind { + lowering::TypeKind::Forall { bindings, inner } => { + let variables = bindings + .iter() + .map(|binding| convert_forall_binding(state, context, binding)) + .collect(); + + let inner = inner.map_or(context.prim.unknown, |id| type_to_core(state, context, id)); + let (arguments, result) = signature_components(state, inner); + + InspectSignature { variables, arguments, result } + } + + lowering::TypeKind::Parenthesized { parenthesized } => { + parenthesized.map(|id| inspect_signature(state, context, id)).unwrap_or_else(unknown) + } + + _ => { + let variables = [].into(); + + let id = type_to_core(state, context, id); + let (arguments, result) = signature_components(state, id); + + InspectSignature { variables, arguments, result } + } + } +} + +fn signature_components(state: &mut CheckState, id: TypeId) -> (Vec, TypeId) { + let mut components = iter::successors(Some(id), |&id| match state.storage[id] { + Type::Function(_, id) => Some(id), + _ => None, + }) + .map(|id| match state.storage[id] { + Type::Function(id, _) => id, + _ => id, + }) + .collect_vec(); + + let Some(id) = components.pop() else { + unreachable!("invariant violated: expected non-empty components"); + }; + + (components, id) +} + +const MISSING_NAME: SmolStr = SmolStr::new_inline(""); pub fn convert_forall_binding( state: &mut CheckState, @@ -143,11 +216,11 @@ where Q: ExternalQueries, { let visible = binding.visible; - let name = binding.name.clone().unwrap_or(INVALID_NAME); + let name = binding.name.clone().unwrap_or(MISSING_NAME); let kind = match binding.kind { Some(id) => type_to_core(state, context, id), - None => state.fresh_unification(context), + None => state.fresh_unification_type(context), }; let level = state.bind_forall(binding.id, kind); diff --git a/compiler-core/checking/src/check/kind.rs b/compiler-core/checking/src/check/kind.rs index ad86d9c1..f933e02d 100644 --- a/compiler-core/checking/src/check/kind.rs +++ b/compiler-core/checking/src/check/kind.rs @@ -254,6 +254,6 @@ where Q: ExternalQueries, { let (inferred_type, inferred_kind) = infer_surface_kind(state, context, id); - unification::unify(state, context, inferred_kind, kind); + unification::subsumes(state, context, inferred_kind, kind); (inferred_type, inferred_kind) } diff --git a/compiler-core/checking/src/check/state.rs b/compiler-core/checking/src/check/state.rs index 06608b96..4f3c07be 100644 --- a/compiler-core/checking/src/check/state.rs +++ b/compiler-core/checking/src/check/state.rs @@ -97,8 +97,11 @@ where { pub queries: &'a Q, pub prim: PrimCore, + + pub id: FileId, pub indexed: Arc, pub lowered: Arc, + pub prim_indexed: Arc, } @@ -116,7 +119,7 @@ where let prim = PrimCore::collect(queries, state)?; let prim_id = queries.prim_id(); let prim_indexed = queries.indexed(prim_id)?; - Ok(CheckContext { queries, prim, indexed, lowered, prim_indexed }) + Ok(CheckContext { queries, prim, id, indexed, lowered, prim_indexed }) } } diff --git a/compiler-core/checking/src/check/unification.rs b/compiler-core/checking/src/check/unification.rs index e8c923fe..c2f26aaf 100644 --- a/compiler-core/checking/src/check/unification.rs +++ b/compiler-core/checking/src/check/unification.rs @@ -80,6 +80,16 @@ where solve(state, context, unification_id, t1).is_some() } + ( + &Type::Variable(Variable::Bound(t1_index)), + &Type::Variable(Variable::Bound(t2_index)), + ) => t1_index == t2_index, + + ( + &Type::Variable(Variable::Skolem(t1_level, _)), + &Type::Variable(Variable::Skolem(t2_level, _)), + ) => t1_level == t2_level, + _ => false, } } diff --git a/compiler-core/checking/src/lib.rs b/compiler-core/checking/src/lib.rs index f27e7467..387e45ea 100644 --- a/compiler-core/checking/src/lib.rs +++ b/compiler-core/checking/src/lib.rs @@ -9,13 +9,13 @@ use building_types::{QueryProxy, QueryResult}; use files::FileId; use indexing::{IndexedModule, TermItemId, TypeItemId}; use itertools::Itertools; -use lowering::{DataIr, LoweredModule, Scc, TermItemIr, TypeItemIr, TypeVariableBinding}; +use lowering::{DataIr, LoweredModule, Scc, TermItemIr, TypeItemIr}; use resolving::ResolvedModule; use rustc_hash::FxHashMap; +use smol_str::SmolStr; -use crate::check::kind::check_surface_kind; -use crate::check::{CheckContext, CheckState, convert, kind, transfer}; -use crate::core::{ForallBinder, Variable, debruijn, pretty}; +use crate::check::{CheckContext, CheckState, convert, kind, quantify, transfer, unification}; +use crate::core::{ForallBinder, Variable, debruijn}; pub trait ExternalQueries: QueryProxy< @@ -81,6 +81,8 @@ fn source_check_module( Ok(state.checked) } +const MISSING_NAME: SmolStr = SmolStr::new_static(""); + fn check_type_item(state: &mut CheckState, context: &CheckContext, item_id: TypeItemId) where Q: ExternalQueries, @@ -88,37 +90,137 @@ where let Some(item) = context.lowered.info.get_type_item(item_id) else { return }; match item { TypeItemIr::DataGroup { signature, data, .. } => { - let signature_type = signature - .map(|signature| convert::signature_type_to_core(state, context, signature)); - - if let Some(DataIr { variables }) = data { - let inferred_type = create_type_declaration_kind(state, context, variables); - for constructor_id in context.indexed.pairs.data_constructors(item_id) { - let Some(TermItemIr::Constructor { arguments }) = - context.lowered.info.get_term_item(constructor_id) - else { - continue; + let Some(DataIr { variables }) = data else { return }; + + let signature = signature.map(|id| convert::inspect_signature(state, context, id)); + + let (kind_variables, type_variables, result_kind) = if let Some(signature) = signature { + if variables.len() != signature.arguments.len() { + todo!("proper arity checking errors innit") + }; + + let variables = variables.iter(); + let arguments = signature.arguments.iter(); + + let kinds = variables.zip(arguments).map(|(variable, &argument)| { + // Use contravariant subtyping for type variables: + // + // data Example :: Argument -> Type + // data Example (a :: Variable) = Example + // + // Signature: Argument -> Type + // Inferred: Variable -> Type + // + // Given + // Variable -> Type <: Argument -> Type + // + // Therefore + // [Argument <: Variable, Type <: Type] + let kind = variable.kind.map_or(argument, |kind| { + let kind = convert::type_to_core(state, context, kind); + let valid = unification::subsumes(state, context, argument, kind); + if valid { kind } else { context.prim.unknown } + }); + + let name = variable.name.clone().unwrap_or(MISSING_NAME); + (variable.id, variable.visible, name, kind) + }); + + let kinds = kinds.collect_vec(); + + let kind_variables = signature.variables; + let result_kind = signature.result; + let type_variables = kinds.into_iter().map(|(id, visible, name, kind)| { + let level = state.bind_forall(id, kind); + ForallBinder { visible, name, level, kind } + }); + + (kind_variables, type_variables.collect_vec(), result_kind) + } else { + let kind_variables = vec![]; + let result_kind = context.prim.t; + let type_variables = variables.iter().map(|variable| { + let kind = match variable.kind { + Some(id) => convert::type_to_core(state, context, id), + None => state.fresh_unification_type(context), }; - for argument in arguments.iter() { - let (inferred_type, inferred_kind) = - check_surface_kind(state, context, *argument, context.prim.t); - { - let inferred_type = pretty::print_local(state, context, inferred_type); - let inferred_kind = pretty::print_local(state, context, inferred_kind); - eprintln!("{inferred_type} :: {inferred_kind}") - } - } - } - { - if let Some(signature_type) = signature_type { - let signature_type = pretty::print_local(state, context, signature_type); - eprintln!("{signature_type}"); - } - let inferred_type = pretty::print_local(state, context, inferred_type); - eprintln!("{inferred_type}"); - } + let visible = variable.visible; + let name = variable.name.clone().unwrap_or(MISSING_NAME); + let level = state.bind_forall(variable.id, kind); + ForallBinder { visible, name, level, kind } + }); + + (kind_variables, type_variables.collect_vec(), result_kind) + }; + + let data_reference = { + let size = state.bound.size(); + let reference_type = state.storage.intern(Type::Constructor(context.id, item_id)); + type_variables.iter().cloned().fold(reference_type, |reference_type, variable| { + let Some(index) = variable.level.to_index(size) else { + let level = variable.level; + unreachable!("invariant violated: invalid {level} for {size}"); + }; + + let variable = Variable::Bound(index); + let variable = state.storage.intern(Type::Variable(variable)); + + state.storage.intern(Type::Application(reference_type, variable)) + }) + }; + + for item_id in context.indexed.pairs.data_constructors(item_id) { + let Some(TermItemIr::Constructor { arguments }) = + context.lowered.info.get_term_item(item_id) + else { + continue; + }; + + let arguments = arguments.iter().map(|&argument| { + let (inferred_type, _) = + kind::check_surface_kind(state, context, argument, context.prim.t); + inferred_type + }); + + let arguments = arguments.collect_vec(); + + let constructor_type = + arguments.into_iter().rfold(data_reference, |result, argument| { + state.storage.intern(Type::Function(argument, result)) + }); + + let all_variables = { + let from_kind = kind_variables.iter(); + let from_type = type_variables.iter(); + from_kind.chain(from_type).cloned() + }; + + let constructor_type = all_variables.rfold(constructor_type, |inner, variable| { + state.storage.intern(Type::Forall(variable, inner)) + }); + + let Some(constructor_type) = quantify::quantify(state, constructor_type) else { + continue; + }; + + let constructor_type = transfer::globalize(state, context, constructor_type); + state.checked.terms.insert(item_id, constructor_type); } + + let type_kind = { + let data_kind = type_variables.iter().rfold(result_kind, |result, variable| { + state.storage.intern(Type::Function(variable.kind, result)) + }); + kind_variables.iter().cloned().rfold(data_kind, |inner, binder| { + state.storage.intern(Type::Forall(binder, inner)) + }) + }; + + if let Some(data_kind) = quantify::quantify(state, type_kind) { + let data_kind = transfer::globalize(state, context, data_kind); + state.checked.types.insert(item_id, data_kind); + }; } TypeItemIr::NewtypeGroup { .. } => (), @@ -139,43 +241,6 @@ where } } -fn create_type_declaration_kind( - state: &mut CheckState, - context: &CheckContext, - bindings: &[TypeVariableBinding], -) -> TypeId -where - Q: ExternalQueries, -{ - let binders = bindings - .iter() - .map(|binding| convert::convert_forall_binding(state, context, binding)) - .collect_vec(); - - // Build the function type for the type declaration e.g. - // - // ```purescript - // data Maybe a = Just a | Nothing - // ``` - // - // function_type := a -> Type - let size = state.bound.size(); - let function_type = binders.iter().rfold(context.prim.t, |result, binder| { - let index = binder.level.to_index(size).unwrap_or_else(|| { - unreachable!("invariant violated: invalid {} for {size}", binder.level) - }); - let variable = state.storage.intern(Type::Variable(Variable::Bound(index))); - state.storage.intern(Type::Function(variable, result)) - }); - - // Qualify the type variables in the function type e.g. - // - // forall (a :: Type). a -> Type - binders - .into_iter() - .rfold(function_type, |inner, binder| state.storage.intern(Type::Forall(binder, inner))) -} - fn prim_check_module( queries: &impl ExternalQueries, file_id: FileId, diff --git a/compiler-core/lowering/src/algorithm/recursive.rs b/compiler-core/lowering/src/algorithm/recursive.rs index 0c9d0439..43ece09c 100644 --- a/compiler-core/lowering/src/algorithm/recursive.rs +++ b/compiler-core/lowering/src/algorithm/recursive.rs @@ -701,8 +701,8 @@ fn lower_type_kind( s.push_forall_scope(); let bindings = cst.children().map(|cst| lower_type_variable_binding(s, context, &cst)).collect(); - let type_ = cst.type_().map(|cst| lower_type(s, context, &cst)); - TypeKind::Forall { bindings, type_ } + let inner = cst.type_().map(|cst| lower_type(s, context, &cst)); + TypeKind::Forall { bindings, inner } }), cst::Type::TypeHole(_) => TypeKind::Hole, cst::Type::TypeInteger(_) => TypeKind::Integer, @@ -779,8 +779,8 @@ pub(crate) fn lower_forall(state: &mut State, context: &Context, cst: &cst::Type state.push_forall_scope(); let bindings = f.children().map(|cst| lower_type_variable_binding(state, context, &cst)).collect(); - let type_ = f.type_().map(|cst| lower_forall(state, context, &cst)); - let kind = TypeKind::Forall { bindings, type_ }; + let inner = f.type_().map(|cst| lower_forall(state, context, &cst)); + let kind = TypeKind::Forall { bindings, inner }; state.associate_type_info(id, kind); id } else { diff --git a/compiler-core/lowering/src/intermediate.rs b/compiler-core/lowering/src/intermediate.rs index ef2fa17f..632509fa 100644 --- a/compiler-core/lowering/src/intermediate.rs +++ b/compiler-core/lowering/src/intermediate.rs @@ -169,7 +169,7 @@ pub enum TypeKind { Arrow { argument: Option, result: Option }, Constrained { constraint: Option, constrained: Option }, Constructor { resolution: Option<(FileId, TypeItemId)> }, - Forall { bindings: Arc<[TypeVariableBinding]>, type_: Option }, + Forall { bindings: Arc<[TypeVariableBinding]>, inner: Option }, Hole, Integer, Kinded { type_: Option, kind: Option }, diff --git a/tests-integration/tests/checking.rs b/tests-integration/tests/checking.rs index ec8aad0f..0fe3c7eb 100644 --- a/tests-integration/tests/checking.rs +++ b/tests-integration/tests/checking.rs @@ -6,6 +6,7 @@ use checking::check::unification::{self, UnificationState}; use checking::check::{CheckContext, CheckState, quantify}; use checking::core::{ForallBinder, Type, TypeId, Variable, debruijn, pretty}; use files::{FileId, Files}; +use indexing::{TermItem, TypeItem}; use lowering::TypeVariableBindingId; struct ContextState<'r> { @@ -362,3 +363,95 @@ fn test_subsumes_nested_forall() { let result = unification::subsumes(state, context, forall_a_b, int_to_string_to_int); assert!(result, "∀a. ∀b. (a -> b -> a) should subsume (Int -> String -> Int)"); } + +#[test] +fn test_manual() { + let (engine, id) = empty_engine(); + engine.set_content( + id, + r#" +module Main where + +data Proxy :: forall k. k -> Type +data Proxy a = Proxy +"#, + ); + + let indexed = engine.indexed(id).unwrap(); + let checked = engine.checked(id).unwrap(); + + for (id, TermItem { name, .. }) in indexed.items.iter_terms() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_term(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } + + for (id, TypeItem { name, .. }) in indexed.items.iter_types() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_type(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } +} + +#[test] +fn test_manual_2() { + let (engine, id) = empty_engine(); + engine.set_content( + id, + r#" +module Main where + +data Maybe :: Type -> Type +data Maybe (a :: Type) = Just a | Nothing +"#, + ); + + let indexed = engine.indexed(id).unwrap(); + let checked = engine.checked(id).unwrap(); + + for (id, TermItem { name, .. }) in indexed.items.iter_terms() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_term(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } + + for (id, TypeItem { name, .. }) in indexed.items.iter_types() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_type(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } +} + +#[test] +fn test_manual_3() { + let (engine, id) = empty_engine(); + engine.set_content( + id, + r#" +module Main where + +data Proxy a = Proxy +"#, + ); + + let indexed = engine.indexed(id).unwrap(); + let checked = engine.checked(id).unwrap(); + + for (id, TermItem { name, .. }) in indexed.items.iter_terms() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_term(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } + + for (id, TypeItem { name, .. }) in indexed.items.iter_types() { + let Some(n) = name else { continue }; + let Some(t) = checked.lookup_type(id) else { continue }; + let t = pretty::print_global(&engine, t); + eprintln!("{n} :: {t}") + } +} From 1e54adf837038a4a1e42cfc82560587721a3b532 Mon Sep 17 00:00:00 2001 From: Justin Garcia Date: Tue, 18 Nov 2025 00:33:18 +0800 Subject: [PATCH 9/9] Rename data_kind to type_kind --- compiler-core/checking/src/lib.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compiler-core/checking/src/lib.rs b/compiler-core/checking/src/lib.rs index 387e45ea..e460abea 100644 --- a/compiler-core/checking/src/lib.rs +++ b/compiler-core/checking/src/lib.rs @@ -209,17 +209,17 @@ where } let type_kind = { - let data_kind = type_variables.iter().rfold(result_kind, |result, variable| { + let type_kind = type_variables.iter().rfold(result_kind, |result, variable| { state.storage.intern(Type::Function(variable.kind, result)) }); - kind_variables.iter().cloned().rfold(data_kind, |inner, binder| { + kind_variables.iter().cloned().rfold(type_kind, |inner, binder| { state.storage.intern(Type::Forall(binder, inner)) }) }; - if let Some(data_kind) = quantify::quantify(state, type_kind) { - let data_kind = transfer::globalize(state, context, data_kind); - state.checked.types.insert(item_id, data_kind); + if let Some(type_kind) = quantify::quantify(state, type_kind) { + let type_kind = transfer::globalize(state, context, type_kind); + state.checked.types.insert(item_id, type_kind); }; }