From 7b15e893a6e826541db5e9a7175e500c75dc6fc6 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Wed, 12 Apr 2023 08:52:55 +0100 Subject: [PATCH 01/13] Add unused tuple types --- llvm-calc4/src/Calc/Compile/ToLLVM.hs | 3 +++ llvm-calc4/src/Calc/ExprUtils.hs | 3 +++ llvm-calc4/src/Calc/Interpreter.hs | 1 + llvm-calc4/src/Calc/TypeUtils.hs | 2 ++ llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 1 + llvm-calc4/src/Calc/Types/Expr.hs | 7 +++++++ llvm-calc4/src/Calc/Types/Type.hs | 11 +++++++++++ 7 files changed, 28 insertions(+) diff --git a/llvm-calc4/src/Calc/Compile/ToLLVM.hs b/llvm-calc4/src/Calc/Compile/ToLLVM.hs index 0eba42e9..8257fc52 100644 --- a/llvm-calc4/src/Calc/Compile/ToLLVM.hs +++ b/llvm-calc4/src/Calc/Compile/ToLLVM.hs @@ -80,6 +80,7 @@ lookupArg identifier = do printFunction :: (LLVM.MonadModuleBuilder m) => Type ann -> m LLVM.Operand printFunction (TPrim _ TInt) = LLVM.extern "printint" [LLVM.i32] LLVM.void printFunction (TPrim _ TBool) = LLVM.extern "printbool" [LLVM.i1] LLVM.void +printFunction (TTuple {}) = error "printFunction TTuple" printFunction (TFunction _ _ tyRet) = printFunction tyRet -- maybe this should be an error instead -- | given our `Module` type, turn it into an LLVM module @@ -152,6 +153,7 @@ functionNameToLLVM (FunctionName fnName) = typeToLLVM :: Type ann -> LLVM.Type typeToLLVM (TPrim _ TBool) = LLVM.i1 typeToLLVM (TPrim _ TInt) = LLVM.i32 +typeToLLVM TTuple {} = error "typeToLLVM TTuple" typeToLLVM (TFunction _ tyArgs tyRet) = LLVM.FunctionType (typeToLLVM tyRet) (typeToLLVM <$> tyArgs) False @@ -243,6 +245,7 @@ exprToLLVM (EPrim _ prim) = pure $ primToLLVM prim exprToLLVM (EVar _ var) = lookupArg var +exprToLLVM (ETuple {}) = error "exprToLLVM ETuple" exprToLLVM (EApply _ fnName args) = do irFunc <- lookupFunction fnName irArgs <- traverse exprToLLVM args diff --git a/llvm-calc4/src/Calc/ExprUtils.hs b/llvm-calc4/src/Calc/ExprUtils.hs index 471e13b3..830ebc88 100644 --- a/llvm-calc4/src/Calc/ExprUtils.hs +++ b/llvm-calc4/src/Calc/ExprUtils.hs @@ -17,6 +17,7 @@ getOuterAnnotation (EPrim ann _) = ann getOuterAnnotation (EIf ann _ _ _) = ann getOuterAnnotation (EVar ann _) = ann getOuterAnnotation (EApply ann _ _) = ann +getOuterAnnotation (ETuple ann _ _ ) = ann -- | modify the outer annotation of an expression -- useful for adding line numbers during parsing @@ -28,6 +29,7 @@ mapOuterExprAnnotation f expr' = EIf ann a b c -> EIf (f ann) a b c EVar ann a -> EVar (f ann) a EApply ann a b -> EApply (f ann) a b + ETuple ann a b -> ETuple (f ann) a b -- | Given a function that changes `Expr` values, apply it throughout -- an AST tree @@ -38,3 +40,4 @@ mapExpr _ (EVar ann a) = EVar ann a mapExpr f (EApply ann fn args) = EApply ann fn (f <$> args) mapExpr f (EIf ann predExpr thenExpr elseExpr) = EIf ann (f predExpr) (f thenExpr) (f elseExpr) +mapExpr f (ETuple ann a as) = ETuple ann (f a) (f <$> as) diff --git a/llvm-calc4/src/Calc/Interpreter.hs b/llvm-calc4/src/Calc/Interpreter.hs index b45e8ad1..2f7c08bf 100644 --- a/llvm-calc4/src/Calc/Interpreter.hs +++ b/llvm-calc4/src/Calc/Interpreter.hs @@ -126,6 +126,7 @@ interpret (EApply _ fnName args) = interpretApply fnName args interpret (EInfix ann op a b) = interpretInfix ann op a b +interpret (ETuple {}) = error "interpret ETuple" interpret (EIf ann predExpr thenExpr elseExpr) = do predA <- interpret predExpr case predA of diff --git a/llvm-calc4/src/Calc/TypeUtils.hs b/llvm-calc4/src/Calc/TypeUtils.hs index 3b84345c..5900737a 100644 --- a/llvm-calc4/src/Calc/TypeUtils.hs +++ b/llvm-calc4/src/Calc/TypeUtils.hs @@ -5,7 +5,9 @@ import Calc.Types.Type getOuterTypeAnnotation :: Type ann -> ann getOuterTypeAnnotation (TPrim ann _) = ann getOuterTypeAnnotation (TFunction ann _ _) = ann +getOuterTypeAnnotation (TTuple ann _ _ ) = ann mapOuterTypeAnnotation :: (ann -> ann) -> Type ann -> Type ann mapOuterTypeAnnotation f (TPrim ann p) = TPrim (f ann) p mapOuterTypeAnnotation f (TFunction ann a b) = TFunction (f ann) a b +mapOuterTypeAnnotation f (TTuple ann a b) = TTuple (f ann) a b diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index ebb2bd31..d548c94c 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -128,6 +128,7 @@ infer (EPrim ann prim) = pure (EPrim (typeFromPrim ann prim) prim) infer (EIf ann predExpr thenExpr elseExpr) = inferIf ann predExpr thenExpr elseExpr +infer (ETuple {}) = error "infer ETuple" infer (EApply ann fnName args) = do fn <- lookupFunction ann fnName (ty, elabArgs) <- case fn of diff --git a/llvm-calc4/src/Calc/Types/Expr.hs b/llvm-calc4/src/Calc/Types/Expr.hs index ae18cd21..87fb90f3 100644 --- a/llvm-calc4/src/Calc/Types/Expr.hs +++ b/llvm-calc4/src/Calc/Types/Expr.hs @@ -9,6 +9,7 @@ import Calc.Types.Identifier import Calc.Types.Prim import Prettyprinter ((<+>)) import qualified Prettyprinter as PP +import qualified Data.List.NonEmpty as NE data Expr ann = EPrim ann Prim @@ -16,6 +17,7 @@ data Expr ann | EIf ann (Expr ann) (Expr ann) (Expr ann) | EVar ann Identifier | EApply ann FunctionName [Expr ann] + | ETuple ann (Expr ann) (NE.NonEmpty (Expr ann)) deriving stock (Eq, Ord, Show, Functor, Foldable, Traversable) -- when on multilines, indent by `i`, if not then nothing @@ -35,6 +37,11 @@ instance PP.Pretty (Expr ann) where pretty (EApply _ fn args) = PP.pretty fn <> "(" <> PP.cat pArgs <> ")" where pArgs = PP.punctuate "," (PP.pretty <$> args) + pretty (ETuple _ a as) = + "(" <> PP.cat (PP.punctuate "," (PP.pretty <$> tupleItems a as)) <> ")" + where + tupleItems :: a -> NE.NonEmpty a -> [a] + tupleItems b bs = b : NE.toList bs data Op = OpAdd diff --git a/llvm-calc4/src/Calc/Types/Type.hs b/llvm-calc4/src/Calc/Types/Type.hs index 5e959b9f..a31d3ba8 100644 --- a/llvm-calc4/src/Calc/Types/Type.hs +++ b/llvm-calc4/src/Calc/Types/Type.hs @@ -5,6 +5,7 @@ module Calc.Types.Type (Type (..), TypePrim (..)) where import qualified Prettyprinter as PP +import qualified Data.List.NonEmpty as NE data TypePrim = TBool | TInt deriving stock (Eq, Ord, Show) @@ -16,6 +17,7 @@ instance PP.Pretty TypePrim where data Type ann = TPrim ann TypePrim | TFunction ann [Type ann] (Type ann) + | TTuple ann (Type ann) (NE.NonEmpty (Type ann)) deriving stock (Eq, Ord, Show, Functor) instance PP.Pretty (Type ann) where @@ -24,3 +26,12 @@ instance PP.Pretty (Type ann) where "(" <> prettyArgs <> ") -> " <> PP.pretty ret where prettyArgs = PP.concatWith (PP.surround PP.comma) (PP.pretty <$> args) + pretty (TTuple _ a as) = + "(" <> PP.cat (PP.punctuate "," (PP.pretty <$> tupleItems a as)) <> ")" + where + tupleItems :: a -> NE.NonEmpty a -> [a] + tupleItems b bs = b : NE.toList bs + + + + From 0869812d19d304235b9368abf48291d69d54f3cc Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Wed, 12 Apr 2023 09:00:12 +0100 Subject: [PATCH 02/13] Broken parser tests for tuple --- llvm-calc4/test/Test/Parser/ParserSpec.hs | 25 ++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/llvm-calc4/test/Test/Parser/ParserSpec.hs b/llvm-calc4/test/Test/Parser/ParserSpec.hs index 25fc3b20..23a29949 100644 --- a/llvm-calc4/test/Test/Parser/ParserSpec.hs +++ b/llvm-calc4/test/Test/Parser/ParserSpec.hs @@ -1,7 +1,8 @@ {-# LANGUAGE OverloadedStrings #-} - + {-# LANGUAGE LambdaCase #-} module Test.Parser.ParserSpec (spec) where +import qualified Data.List.NonEmpty as NE import Calc import Data.Foldable (traverse_) import Data.Functor @@ -18,13 +19,30 @@ bool = EPrim mempty . PBool var :: (Monoid ann) => String -> Expr ann var = EVar mempty . Identifier . fromString +tuple :: (Monoid ann) => [Expr ann] -> Expr ann +tuple = \case + (a : b : rest) -> ETuple mempty a (b NE.:| rest) + _ -> error "not enough items for tuple" + +tyInt :: (Monoid ann) => Type ann +tyInt = TPrim mempty TInt + +tyBool :: (Monoid ann) => Type ann +tyBool = TPrim mempty TBool + +tyTuple :: (Monoid ann) => [Type ann] -> Type ann +tyTuple = \case + (a : b : rest) -> TTuple mempty a (b NE.:| rest) + _ -> error "not enough items for tyTuple" + spec :: Spec spec = do describe "ParserSpec" $ do describe "Type" $ do let strings = - [ ("Boolean", TPrim () TBool), - ("Integer", TPrim () TInt) + [ ("Boolean", tyBool), + ("Integer", tyInt), + ("(Boolean, Boolean, Integer)", tyTuple [tyBool, tyBool, tyInt]) ] traverse_ ( \(str, expr) -> it (T.unpack str) $ do @@ -83,6 +101,7 @@ spec = do ("1 + 2", EInfix () OpAdd (int 1) (int 2)), ("True", EPrim () (PBool True)), ("False", EPrim () (PBool False)), + ("(1,2,True)", tuple [int 1, int 2, bool True]), ( "1 + 2 + 3", EInfix () From 98599aa94bb0b5a14b6db2f2b4000c00c44c060c Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Wed, 12 Apr 2023 22:06:50 +0100 Subject: [PATCH 03/13] Ready to typecheck --- llvm-calc4/src/Calc/Compile/ToLLVM.hs | 1 + llvm-calc4/src/Calc/ExprUtils.hs | 7 +- llvm-calc4/src/Calc/Interpreter.hs | 1 + llvm-calc4/src/Calc/Parser/Expr.hs | 51 +++++++++++++- llvm-calc4/src/Calc/Parser/Pattern.hs | 66 +++++++++++++++++++ llvm-calc4/src/Calc/Parser/Primitives.hs | 15 ++++- llvm-calc4/src/Calc/Parser/Shared.hs | 6 +- llvm-calc4/src/Calc/Parser/Type.hs | 18 ++++- llvm-calc4/src/Calc/TypeUtils.hs | 2 +- llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 1 + llvm-calc4/src/Calc/Types.hs | 2 + llvm-calc4/src/Calc/Types/Expr.hs | 50 ++++++++++++-- llvm-calc4/src/Calc/Types/Pattern.hs | 34 ++++++++++ llvm-calc4/src/Calc/Types/Type.hs | 12 ++-- llvm-calc4/test/Test/Parser/ParserSpec.hs | 26 +++++++- .../test/Test/Typecheck/TypecheckSpec.hs | 4 +- 16 files changed, 270 insertions(+), 26 deletions(-) create mode 100644 llvm-calc4/src/Calc/Parser/Pattern.hs create mode 100644 llvm-calc4/src/Calc/Types/Pattern.hs diff --git a/llvm-calc4/src/Calc/Compile/ToLLVM.hs b/llvm-calc4/src/Calc/Compile/ToLLVM.hs index 8257fc52..3d39d722 100644 --- a/llvm-calc4/src/Calc/Compile/ToLLVM.hs +++ b/llvm-calc4/src/Calc/Compile/ToLLVM.hs @@ -246,6 +246,7 @@ exprToLLVM (EPrim _ prim) = exprToLLVM (EVar _ var) = lookupArg var exprToLLVM (ETuple {}) = error "exprToLLVM ETuple" +exprToLLVM (EPatternMatch {}) = error "exprToLLVM EPatternMatch" exprToLLVM (EApply _ fnName args) = do irFunc <- lookupFunction fnName irArgs <- traverse exprToLLVM args diff --git a/llvm-calc4/src/Calc/ExprUtils.hs b/llvm-calc4/src/Calc/ExprUtils.hs index 830ebc88..bc1be3dd 100644 --- a/llvm-calc4/src/Calc/ExprUtils.hs +++ b/llvm-calc4/src/Calc/ExprUtils.hs @@ -8,6 +8,7 @@ module Calc.ExprUtils where import Calc.Types +import Data.Bifunctor (second) -- | get the annotation in the first leaf found in an `Expr`. -- useful for getting the overall type of an expression @@ -17,7 +18,8 @@ getOuterAnnotation (EPrim ann _) = ann getOuterAnnotation (EIf ann _ _ _) = ann getOuterAnnotation (EVar ann _) = ann getOuterAnnotation (EApply ann _ _) = ann -getOuterAnnotation (ETuple ann _ _ ) = ann +getOuterAnnotation (ETuple ann _ _) = ann +getOuterAnnotation (EPatternMatch ann _ _) = ann -- | modify the outer annotation of an expression -- useful for adding line numbers during parsing @@ -30,6 +32,7 @@ mapOuterExprAnnotation f expr' = EVar ann a -> EVar (f ann) a EApply ann a b -> EApply (f ann) a b ETuple ann a b -> ETuple (f ann) a b + EPatternMatch ann a b -> EPatternMatch (f ann) a b -- | Given a function that changes `Expr` values, apply it throughout -- an AST tree @@ -41,3 +44,5 @@ mapExpr f (EApply ann fn args) = EApply ann fn (f <$> args) mapExpr f (EIf ann predExpr thenExpr elseExpr) = EIf ann (f predExpr) (f thenExpr) (f elseExpr) mapExpr f (ETuple ann a as) = ETuple ann (f a) (f <$> as) +mapExpr f (EPatternMatch ann matchExpr patterns) = + EPatternMatch ann (f matchExpr) (fmap (second f) patterns) diff --git a/llvm-calc4/src/Calc/Interpreter.hs b/llvm-calc4/src/Calc/Interpreter.hs index 2f7c08bf..e9ec956f 100644 --- a/llvm-calc4/src/Calc/Interpreter.hs +++ b/llvm-calc4/src/Calc/Interpreter.hs @@ -127,6 +127,7 @@ interpret (EApply _ fnName args) = interpret (EInfix ann op a b) = interpretInfix ann op a b interpret (ETuple {}) = error "interpret ETuple" +interpret (EPatternMatch {}) = error "interpret EPatternMatch" interpret (EIf ann predExpr thenExpr elseExpr) = do predA <- interpret predExpr case predA of diff --git a/llvm-calc4/src/Calc/Parser/Expr.hs b/llvm-calc4/src/Calc/Parser/Expr.hs index 34dadba3..a5f5313c 100644 --- a/llvm-calc4/src/Calc/Parser/Expr.hs +++ b/llvm-calc4/src/Calc/Parser/Expr.hs @@ -3,12 +3,14 @@ module Calc.Parser.Expr (exprParser) where import Calc.Parser.Identifier +import Calc.Parser.Pattern import Calc.Parser.Primitives import Calc.Parser.Shared import Calc.Parser.Types import Calc.Types.Annotation import Calc.Types.Expr import Control.Monad.Combinators.Expr +import qualified Data.List.NonEmpty as NE import Data.Text import Text.Megaparsec @@ -17,8 +19,10 @@ exprParser = addLocation (makeExprParser exprPart table) "expression" exprPart :: Parser (Expr Annotation) exprPart = - inBrackets (addLocation exprParser) - <|> primParser + try tupleParser + <|> inBrackets (addLocation exprParser) + <|> patternMatchParser + <|> primExprParser <|> ifParser <|> try applyParser <|> varParser @@ -55,3 +59,46 @@ applyParser = addLocation $ do args <- sepBy exprParser (stringLiteral ",") stringLiteral ")" pure (EApply mempty fnName args) + +tupleParser :: Parser (Expr Annotation) +tupleParser = label "tuple" $ + addLocation $ do + _ <- stringLiteral "(" + neArgs <- NE.fromList <$> sepBy1 exprParser (stringLiteral ",") + neTail <- case NE.nonEmpty (NE.tail neArgs) of + Just ne -> pure ne + _ -> fail "Expected at least two items in a tuple" + _ <- stringLiteral ")" + pure (ETuple mempty (NE.head neArgs) neTail) + +----- + +patternMatchParser :: Parser ParserExpr +patternMatchParser = addLocation $ do + matchExpr <- matchExprWithParser + patterns <- + try patternMatchesParser + <|> pure <$> patternCaseParser + case NE.nonEmpty patterns of + (Just nePatterns) -> pure $ EPatternMatch mempty matchExpr nePatterns + _ -> error "need at least one pattern" + +matchExprWithParser :: Parser ParserExpr +matchExprWithParser = do + stringLiteral "case" + sumExpr <- exprParser + stringLiteral "of" + pure sumExpr + +patternMatchesParser :: Parser [(ParserPattern, ParserExpr)] +patternMatchesParser = + sepBy + patternCaseParser + (stringLiteral "|") + +patternCaseParser :: Parser (ParserPattern, ParserExpr) +patternCaseParser = do + pat <- orInBrackets patternParser + stringLiteral "->" + patExpr <- exprParser + pure (pat, patExpr) diff --git a/llvm-calc4/src/Calc/Parser/Pattern.hs b/llvm-calc4/src/Calc/Parser/Pattern.hs new file mode 100644 index 00000000..043b8ae2 --- /dev/null +++ b/llvm-calc4/src/Calc/Parser/Pattern.hs @@ -0,0 +1,66 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Calc.Parser.Pattern + ( patternParser, + ParserPattern, + ) +where + +import Calc.Parser.Primitives +import Calc.Parser.Identifier +import qualified Data.List.NonEmpty as NE +import Calc.Types.Pattern +import Calc.Parser.Shared +import Calc.Types +import Text.Megaparsec +import Text.Megaparsec.Char +import Calc.Parser.Types + +type ParserPattern = Pattern Annotation + +patternParser :: Parser ParserPattern +patternParser = + label + "pattern match" + ( orInBrackets + ( + try patTupleParser + <|> try patWildcardParser + <|> try patVariableParser + <|> patLitParser + ) + ) + +---- + +patWildcardParser :: Parser ParserPattern +patWildcardParser = + myLexeme $ + withLocation + (\loc _ -> PWildcard loc) + (string "_") + +---- + +patVariableParser :: Parser ParserPattern +patVariableParser = + myLexeme $ withLocation PVar identifierParser + +---- + +patTupleParser :: Parser ParserPattern +patTupleParser = label "tuple" $ + withLocation (\loc (pHead, pTail) -> PTuple loc pHead pTail) $ do + _ <- stringLiteral "(" + neArgs <- NE.fromList <$> sepBy1 patternParser (stringLiteral ",") + neTail <- case NE.nonEmpty (NE.tail neArgs) of + Just ne -> pure ne + _ -> fail "Expected at least two items in a tuple" + _ <- stringLiteral ")" + pure (NE.head neArgs, neTail) + +---- + +patLitParser :: Parser ParserPattern +patLitParser = withLocation PLiteral primParser + diff --git a/llvm-calc4/src/Calc/Parser/Primitives.hs b/llvm-calc4/src/Calc/Parser/Primitives.hs index ff63ead8..11ce8957 100644 --- a/llvm-calc4/src/Calc/Parser/Primitives.hs +++ b/llvm-calc4/src/Calc/Parser/Primitives.hs @@ -1,7 +1,8 @@ {-# LANGUAGE OverloadedStrings #-} module Calc.Parser.Primitives - ( primParser, + ( primExprParser, + primParser, intParser, ) where @@ -37,10 +38,18 @@ falseParser = stringLiteral "False" $> False --- -primParser :: Parser ParserExpr -primParser = +primExprParser :: Parser ParserExpr +primExprParser = myLexeme $ addLocation $ EPrim mempty . PInt <$> intParser <|> EPrim mempty <$> truePrimParser <|> EPrim mempty <$> falsePrimParser + +---- + +primParser :: Parser Prim +primParser = + PInt <$> intParser + <|> truePrimParser + <|> falsePrimParser diff --git a/llvm-calc4/src/Calc/Parser/Shared.hs b/llvm-calc4/src/Calc/Parser/Shared.hs index a57af351..0f3a8c4e 100644 --- a/llvm-calc4/src/Calc/Parser/Shared.hs +++ b/llvm-calc4/src/Calc/Parser/Shared.hs @@ -1,7 +1,8 @@ {-# LANGUAGE OverloadedStrings #-} module Calc.Parser.Shared - ( inBrackets, + ( orInBrackets, + inBrackets, myLexeme, withLocation, stringLiteral, @@ -47,6 +48,9 @@ addTypeLocation = withLocation (mapOuterTypeAnnotation . const) inBrackets :: Parser a -> Parser a inBrackets = between2 '(' ')' +orInBrackets :: Parser a -> Parser a +orInBrackets parser = try parser <|> try (inBrackets parser) + myLexeme :: Parser a -> Parser a myLexeme = L.lexeme (L.space space1 empty empty) diff --git a/llvm-calc4/src/Calc/Parser/Type.hs b/llvm-calc4/src/Calc/Parser/Type.hs index cbaacc2d..29f6673c 100644 --- a/llvm-calc4/src/Calc/Parser/Type.hs +++ b/llvm-calc4/src/Calc/Parser/Type.hs @@ -10,19 +10,33 @@ import Calc.Parser.Shared import Calc.Parser.Types import Calc.Types.Type import Data.Functor (($>)) +import qualified Data.List.NonEmpty as NE import Text.Megaparsec ( MonadParsec (try), + label, + sepBy1, (<|>), ) -- | top-level parser for type signatures typeParser :: Parser ParserType typeParser = - myLexeme (addTypeLocation tyPrimitiveParser) + tyPrimitiveParser <|> tyTupleParser tyPrimitiveParser :: Parser ParserType -tyPrimitiveParser = TPrim mempty <$> tyPrimParser +tyPrimitiveParser = myLexeme $ addTypeLocation $ TPrim mempty <$> tyPrimParser where tyPrimParser = try (stringLiteral "Boolean" $> TBool) <|> try (stringLiteral "Integer" $> TInt) + +tyTupleParser :: Parser ParserType +tyTupleParser = label "tuple" $ + addTypeLocation $ do + _ <- stringLiteral "(" + neArgs <- NE.fromList <$> sepBy1 typeParser (stringLiteral ",") + neTail <- case NE.nonEmpty (NE.tail neArgs) of + Just ne -> pure ne + _ -> fail "Expected at least two items in a tuple" + _ <- stringLiteral ")" + pure (TTuple mempty (NE.head neArgs) neTail) diff --git a/llvm-calc4/src/Calc/TypeUtils.hs b/llvm-calc4/src/Calc/TypeUtils.hs index 5900737a..53825e4f 100644 --- a/llvm-calc4/src/Calc/TypeUtils.hs +++ b/llvm-calc4/src/Calc/TypeUtils.hs @@ -5,7 +5,7 @@ import Calc.Types.Type getOuterTypeAnnotation :: Type ann -> ann getOuterTypeAnnotation (TPrim ann _) = ann getOuterTypeAnnotation (TFunction ann _ _) = ann -getOuterTypeAnnotation (TTuple ann _ _ ) = ann +getOuterTypeAnnotation (TTuple ann _ _) = ann mapOuterTypeAnnotation :: (ann -> ann) -> Type ann -> Type ann mapOuterTypeAnnotation f (TPrim ann p) = TPrim (f ann) p diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index d548c94c..9c9223e8 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -129,6 +129,7 @@ infer (EPrim ann prim) = infer (EIf ann predExpr thenExpr elseExpr) = inferIf ann predExpr thenExpr elseExpr infer (ETuple {}) = error "infer ETuple" +infer (EPatternMatch {}) = error "infer EPatternMatch" infer (EApply ann fnName args) = do fn <- lookupFunction ann fnName (ty, elabArgs) <- case fn of diff --git a/llvm-calc4/src/Calc/Types.hs b/llvm-calc4/src/Calc/Types.hs index 2e4bf622..2f428857 100644 --- a/llvm-calc4/src/Calc/Types.hs +++ b/llvm-calc4/src/Calc/Types.hs @@ -6,6 +6,7 @@ module Calc.Types module Calc.Types.Module, module Calc.Types.Prim, module Calc.Types.Type, + module Calc.Types.Pattern, ) where @@ -14,5 +15,6 @@ import Calc.Types.Expr import Calc.Types.Function import Calc.Types.Identifier import Calc.Types.Module +import Calc.Types.Pattern import Calc.Types.Prim import Calc.Types.Type diff --git a/llvm-calc4/src/Calc/Types/Expr.hs b/llvm-calc4/src/Calc/Types/Expr.hs index 87fb90f3..0b88f8d0 100644 --- a/llvm-calc4/src/Calc/Types/Expr.hs +++ b/llvm-calc4/src/Calc/Types/Expr.hs @@ -6,10 +6,11 @@ module Calc.Types.Expr (Expr (..), Op (..)) where import Calc.Types.FunctionName import Calc.Types.Identifier +import Calc.Types.Pattern import Calc.Types.Prim +import qualified Data.List.NonEmpty as NE import Prettyprinter ((<+>)) import qualified Prettyprinter as PP -import qualified Data.List.NonEmpty as NE data Expr ann = EPrim ann Prim @@ -18,6 +19,7 @@ data Expr ann | EVar ann Identifier | EApply ann FunctionName [Expr ann] | ETuple ann (Expr ann) (NE.NonEmpty (Expr ann)) + | EPatternMatch ann (Expr ann) (NE.NonEmpty (Pattern ann, Expr ann)) deriving stock (Eq, Ord, Show, Functor, Foldable, Traversable) -- when on multilines, indent by `i`, if not then nothing @@ -39,9 +41,49 @@ instance PP.Pretty (Expr ann) where pArgs = PP.punctuate "," (PP.pretty <$> args) pretty (ETuple _ a as) = "(" <> PP.cat (PP.punctuate "," (PP.pretty <$> tupleItems a as)) <> ")" - where - tupleItems :: a -> NE.NonEmpty a -> [a] - tupleItems b bs = b : NE.toList bs + where + tupleItems :: a -> NE.NonEmpty a -> [a] + tupleItems b bs = b : NE.toList bs + pretty (EPatternMatch _ matchExp patterns) = + prettyPatternMatch matchExp patterns + +prettyPatternMatch :: + Expr ann -> + NE.NonEmpty (Pattern ann, Expr ann) -> + PP.Doc style +prettyPatternMatch sumExpr matches = + "match" + <+> printSubExpr sumExpr + <+> "with" + <+> PP.line + <> PP.indent + 2 + ( PP.align $ + PP.vsep + ( zipWith + (<+>) + (" " : repeat "|") + (printMatch <$> NE.toList matches) + ) + ) + where + printMatch (construct, expr') = + PP.pretty construct + <+> "->" + <+> PP.line + <> indentMulti 4 (printSubExpr expr') + +-- print simple things with no brackets, and complex things inside brackets +printSubExpr :: Expr ann -> PP.Doc style +printSubExpr expr = case expr of + all'@EIf {} -> inParens all' + all'@EApply {} -> inParens all' + all'@ETuple {} -> inParens all' + all'@EPatternMatch {} -> inParens all' + a -> PP.pretty a + +inParens :: Expr ann -> PP.Doc style +inParens = PP.parens . PP.pretty data Op = OpAdd diff --git a/llvm-calc4/src/Calc/Types/Pattern.hs b/llvm-calc4/src/Calc/Types/Pattern.hs new file mode 100644 index 00000000..9e65c789 --- /dev/null +++ b/llvm-calc4/src/Calc/Types/Pattern.hs @@ -0,0 +1,34 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE OverloadedStrings #-} + +module Calc.Types.Pattern (Pattern (..)) where + +import Calc.Types.Identifier +import Calc.Types.Prim +import qualified Data.List.NonEmpty as NE +import GHC.Generics +import qualified Prettyprinter as PP + +data Pattern ann + = PWildcard ann + | PVar ann Identifier + | PTuple ann (Pattern ann) (NE.NonEmpty (Pattern ann)) + | PLiteral ann Prim + deriving stock + ( Eq, + Ord, + Show, + Functor, + Foldable, + Generic, + Traversable + ) + +instance PP.Pretty (Pattern ann) where + pretty (PWildcard _) = "_" + pretty (PVar _ a) = PP.pretty a + pretty (PLiteral _ lit) = PP.pretty lit + pretty (PTuple _ a as) = + "(" <> PP.hsep (PP.punctuate ", " (PP.pretty <$> ([a] <> NE.toList as))) <> ")" diff --git a/llvm-calc4/src/Calc/Types/Type.hs b/llvm-calc4/src/Calc/Types/Type.hs index a31d3ba8..724594b9 100644 --- a/llvm-calc4/src/Calc/Types/Type.hs +++ b/llvm-calc4/src/Calc/Types/Type.hs @@ -4,8 +4,8 @@ module Calc.Types.Type (Type (..), TypePrim (..)) where -import qualified Prettyprinter as PP import qualified Data.List.NonEmpty as NE +import qualified Prettyprinter as PP data TypePrim = TBool | TInt deriving stock (Eq, Ord, Show) @@ -28,10 +28,6 @@ instance PP.Pretty (Type ann) where prettyArgs = PP.concatWith (PP.surround PP.comma) (PP.pretty <$> args) pretty (TTuple _ a as) = "(" <> PP.cat (PP.punctuate "," (PP.pretty <$> tupleItems a as)) <> ")" - where - tupleItems :: a -> NE.NonEmpty a -> [a] - tupleItems b bs = b : NE.toList bs - - - - + where + tupleItems :: a -> NE.NonEmpty a -> [a] + tupleItems b bs = b : NE.toList bs diff --git a/llvm-calc4/test/Test/Parser/ParserSpec.hs b/llvm-calc4/test/Test/Parser/ParserSpec.hs index 23a29949..0ac5e3c8 100644 --- a/llvm-calc4/test/Test/Parser/ParserSpec.hs +++ b/llvm-calc4/test/Test/Parser/ParserSpec.hs @@ -1,11 +1,12 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} - {-# LANGUAGE LambdaCase #-} + module Test.Parser.ParserSpec (spec) where -import qualified Data.List.NonEmpty as NE import Calc import Data.Foldable (traverse_) import Data.Functor +import qualified Data.List.NonEmpty as NE import Data.String import qualified Data.Text as T import Test.Hspec @@ -24,6 +25,10 @@ tuple = \case (a : b : rest) -> ETuple mempty a (b NE.:| rest) _ -> error "not enough items for tuple" +patternMatch :: (Monoid ann) => Expr ann -> [(Pattern ann, Expr ann)] -> Expr ann +patternMatch matchExpr matches = + EPatternMatch mempty matchExpr (NE.fromList matches) + tyInt :: (Monoid ann) => Type ann tyInt = TPrim mempty TInt @@ -35,6 +40,14 @@ tyTuple = \case (a : b : rest) -> TTuple mempty a (b NE.:| rest) _ -> error "not enough items for tyTuple" +patTuple :: (Monoid ann) => [Pattern ann] -> Pattern ann +patTuple = \case + (a : b : rest) -> PTuple mempty a (b NE.:| rest) + _ -> error "not enough items for patTuple" + +patInt :: (Monoid ann) => Integer -> Pattern ann +patInt = PLiteral mempty . PInt + spec :: Spec spec = do describe "ParserSpec" $ do @@ -118,7 +131,14 @@ spec = do ("if True then 1 else 2", EIf () (bool True) (int 1) (int 2)), ("a + 1", EInfix () OpAdd (var "a") (int 1)), ("add(1,2)", EApply () "add" [int 1, int 2]), - ("go()", EApply () "go" []) + ("go()", EApply () "go" []), + ( "case (1,2,3) of (5,6,7) -> True | (1,2,3) -> False", + patternMatch + (tuple [int 1, int 2, int 3]) + [ (patTuple [patInt 5, patInt 6, patInt 7], bool True), + (patTuple [patInt 1, patInt 2, patInt 3], bool False) + ] + ) ] traverse_ ( \(str, expr) -> it (T.unpack str) $ do diff --git a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs index f9374048..acfb9c2a 100644 --- a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs +++ b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs @@ -101,7 +101,9 @@ spec = do ("1 - 10", "Integer"), ("2 == 2", "Boolean"), ("if True then 1 else 2", "Integer"), - ("if False then True else False", "Boolean") + ("if False then True else False", "Boolean"), + ("(1,2,True)", "(Integer,Integer,Boolean)"), + ("case (1,2,3) of (a,b,_) -> a + b", "Integer") ] describe "Successfully typechecking expressions" $ do From 17412487747c30f4f940caeed4dccd864b2cf394 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Thu, 13 Apr 2023 09:01:08 +0100 Subject: [PATCH 04/13] Copy pasta that typechecker --- llvm-calc4/src/Calc/Parser/Pattern.hs | 11 ++-- llvm-calc4/src/Calc/PatternUtils.hs | 9 ++++ llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 63 ++++++++++++++++++++-- llvm-calc4/src/Calc/Typecheck/Error.hs | 14 +++++ llvm-calc4/src/Calc/Utils.hs | 20 +++++++ 5 files changed, 106 insertions(+), 11 deletions(-) create mode 100644 llvm-calc4/src/Calc/PatternUtils.hs create mode 100644 llvm-calc4/src/Calc/Utils.hs diff --git a/llvm-calc4/src/Calc/Parser/Pattern.hs b/llvm-calc4/src/Calc/Parser/Pattern.hs index 043b8ae2..58b441e7 100644 --- a/llvm-calc4/src/Calc/Parser/Pattern.hs +++ b/llvm-calc4/src/Calc/Parser/Pattern.hs @@ -6,15 +6,14 @@ module Calc.Parser.Pattern ) where -import Calc.Parser.Primitives import Calc.Parser.Identifier -import qualified Data.List.NonEmpty as NE -import Calc.Types.Pattern +import Calc.Parser.Primitives import Calc.Parser.Shared +import Calc.Parser.Types import Calc.Types +import qualified Data.List.NonEmpty as NE import Text.Megaparsec import Text.Megaparsec.Char -import Calc.Parser.Types type ParserPattern = Pattern Annotation @@ -23,8 +22,7 @@ patternParser = label "pattern match" ( orInBrackets - ( - try patTupleParser + ( try patTupleParser <|> try patWildcardParser <|> try patVariableParser <|> patLitParser @@ -63,4 +61,3 @@ patTupleParser = label "tuple" $ patLitParser :: Parser ParserPattern patLitParser = withLocation PLiteral primParser - diff --git a/llvm-calc4/src/Calc/PatternUtils.hs b/llvm-calc4/src/Calc/PatternUtils.hs new file mode 100644 index 00000000..c6b64fe1 --- /dev/null +++ b/llvm-calc4/src/Calc/PatternUtils.hs @@ -0,0 +1,9 @@ +module Calc.PatternUtils (getPatternAnnotation) where + +import Calc.Types.Pattern + +getPatternAnnotation :: Pattern ann -> ann +getPatternAnnotation (PLiteral ann _) = ann +getPatternAnnotation (PWildcard ann) = ann +getPatternAnnotation (PVar ann _ ) = ann +getPatternAnnotation (PTuple ann _ _) = ann diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index 9c9223e8..3acf89fd 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -5,18 +5,25 @@ module Calc.Typecheck.Elaborate (elaborate, elaborateFunction, elaborateModule) where import Calc.ExprUtils +import Calc.PatternUtils import Calc.TypeUtils import Calc.Typecheck.Error import Calc.Typecheck.Types import Calc.Types.Expr import Calc.Types.Function +import Calc.Types.Identifier import Calc.Types.Module +import Calc.Types.Pattern import Calc.Types.Prim import Calc.Types.Type import Control.Monad (when, zipWithM) +import Calc.Utils import Control.Monad.Except import Data.Bifunctor (second) import Data.Functor +import qualified Data.List.NonEmpty as NE +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M elaborateModule :: forall ann. @@ -53,6 +60,34 @@ check ty expr = do then pure (expr $> ty) else throwError (TypeMismatch ty (getOuterAnnotation exprA)) +-- given the type of the expression in a pattern match, +-- check that the pattern makes sense with it +checkPattern :: + ( Show ann + ) => + Type ann -> + Pattern ann -> + TypecheckM + ann + ( Pattern (Type ann), + Map Identifier (Type ann) + ) +checkPattern checkTy checkPat = do + case (checkTy, checkPat) of + (TTuple _ tA tRest, PTuple ann pA pRest) | length tRest == length pRest -> do + (patA, envA) <- checkPattern tA pA + (patRest, envRest) <- neUnzip <$> neZipWithM checkPattern tRest pRest + let ty = TTuple ann (getPatternAnnotation patA) (getPatternAnnotation <$> patRest) + env = envA <> mconcat (NE.toList envRest) + pure (PTuple ty patA patRest, env) + (ty, PVar _ ident) -> + pure (PVar ty ident, M.singleton ident ty) + (ty, PWildcard _) -> pure (PWildcard ty, mempty) + (ty@(TPrim _ tPrim), PLiteral _ pPrim) + | tPrim == typePrimFromPrim pPrim -> + pure (PLiteral ty pPrim, mempty) + (otherTy, otherPat) -> throwError (PatternMismatch otherPat otherTy) + inferIf :: ann -> Expr ann -> @@ -128,8 +163,25 @@ infer (EPrim ann prim) = pure (EPrim (typeFromPrim ann prim) prim) infer (EIf ann predExpr thenExpr elseExpr) = inferIf ann predExpr thenExpr elseExpr -infer (ETuple {}) = error "infer ETuple" -infer (EPatternMatch {}) = error "infer EPatternMatch" +infer (ETuple ann fstExpr restExpr) = do + typedFst <- infer fstExpr + typedRest <- traverse infer restExpr + let typ = + TTuple + ann + (getOuterAnnotation typedFst) + (getOuterAnnotation <$> typedRest) + pure $ ETuple typ typedFst typedRest +infer (EPatternMatch ann matchExpr pats) = do + elabExpr <- infer matchExpr + let withPair (pat, patExpr) = do + (elabPat, newVars) <- checkPattern (getOuterAnnotation elabExpr) pat + elabPatExpr <- withNewVars newVars (infer patExpr) + pure (elabPat, elabPatExpr) + elabPats <- traverse withPair pats + let allTypes = getOuterAnnotation . snd <$> elabPats + typ <- combineMany allTypes + pure (EPatternMatch typ elabExpr elabPats) infer (EApply ann fnName args) = do fn <- lookupFunction ann fnName (ty, elabArgs) <- case fn of @@ -147,6 +199,9 @@ infer (EVar ann var) = do infer (EInfix ann op a b) = inferInfix ann op a b +typePrimFromPrim :: Prim -> TypePrim +typePrimFromPrim (PInt _) = TInt +typePrimFromPrim (PBool _) = TBool + typeFromPrim :: ann -> Prim -> Type ann -typeFromPrim ann (PInt _) = TPrim ann TInt -typeFromPrim ann (PBool _) = TPrim ann TBool +typeFromPrim ann prim = TPrim ann (typePrimFromPrim prim) diff --git a/llvm-calc4/src/Calc/Typecheck/Error.hs b/llvm-calc4/src/Calc/Typecheck/Error.hs index 55362c7a..1b0ab541 100644 --- a/llvm-calc4/src/Calc/Typecheck/Error.hs +++ b/llvm-calc4/src/Calc/Typecheck/Error.hs @@ -4,12 +4,14 @@ module Calc.Typecheck.Error (TypeError (..), typeErrorDiagnostic) where +import Calc.PatternUtils import Calc.SourceSpan import Calc.TypeUtils import Calc.Types.Annotation import Calc.Types.Expr import Calc.Types.FunctionName import Calc.Types.Identifier +import Calc.Types.Pattern import Calc.Types.Type import Data.HashSet (HashSet) import qualified Data.HashSet as HS @@ -28,6 +30,7 @@ data TypeError ann | FunctionNotFound ann FunctionName (HashSet FunctionName) | FunctionArgumentLengthMismatch ann Int Int -- expected, actual | NonFunctionTypeFound ann (Type ann) + | PatternMismatch (Pattern ann) (Type ann) deriving stock (Eq, Ord, Show) positionFromAnnotation :: @@ -180,6 +183,17 @@ typeErrorDiagnostic input e = ] ) [Diag.Note $ "Available in scope: " <> prettyPrint (prettyHashset existing)] + (PatternMismatch pat ty) -> + Diag.Err + Nothing + "Pattern mismatch!" + ( catMaybes + [ (,) + <$> positionFromAnnotation filename input (getPatternAnnotation pat) + <*> pure (Diag.This (prettyPrint $ "This should have type " <> PP.pretty ty)) + ] + ) + [] (FunctionNotFound ann fnName existing) -> Diag.Err Nothing diff --git a/llvm-calc4/src/Calc/Utils.hs b/llvm-calc4/src/Calc/Utils.hs new file mode 100644 index 00000000..ccb646ae --- /dev/null +++ b/llvm-calc4/src/Calc/Utils.hs @@ -0,0 +1,20 @@ +module Calc.Utils (neZipWithM, neUnzip) where + +-- useful junk goes here + +import qualified Data.List.NonEmpty as NE +import Data.Bifunctor +import Control.Monad (zipWithM) + +neZipWithM :: + (Applicative m) => + (a -> b -> m c) -> + NE.NonEmpty a -> + NE.NonEmpty b -> + m (NE.NonEmpty c) +neZipWithM f as bs = + NE.fromList <$> zipWithM f (NE.toList as) (NE.toList bs) + +neUnzip :: NE.NonEmpty (a, b) -> (NE.NonEmpty a, NE.NonEmpty b) +neUnzip = bimap NE.fromList NE.fromList . unzip . NE.toList + From 672edb70c90ec5e70a30ed5b505c1d84b8c82af9 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Sat, 15 Apr 2023 15:01:23 +0100 Subject: [PATCH 05/13] Pattern match typechecking --- llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 25 ++++++---- llvm-calc4/src/Calc/Typecheck/Types.hs | 15 +++--- llvm-calc4/test/Test/Helpers.hs | 47 +++++++++++++++++++ llvm-calc4/test/Test/Parser/ParserSpec.hs | 40 +--------------- .../test/Test/Typecheck/TypecheckSpec.hs | 23 +++++---- 5 files changed, 87 insertions(+), 63 deletions(-) create mode 100644 llvm-calc4/test/Test/Helpers.hs diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index 3acf89fd..98210440 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -4,6 +4,7 @@ module Calc.Typecheck.Elaborate (elaborate, elaborateFunction, elaborateModule) where +import Data.Foldable (foldrM) import Calc.ExprUtils import Calc.PatternUtils import Calc.TypeUtils @@ -56,15 +57,23 @@ elaborate = runTypecheckM (TypecheckEnv mempty) . infer check :: Type ann -> Expr ann -> TypecheckM ann (Expr (Type ann)) check ty expr = do exprA <- infer expr - if void (getOuterAnnotation exprA) == void ty - then pure (expr $> ty) - else throwError (TypeMismatch ty (getOuterAnnotation exprA)) + _ <- checkTypeIsEqual ty (getOuterAnnotation exprA) + pure (expr $> ty) + +-- simple check for now +checkTypeIsEqual :: Type ann -> Type ann -> TypecheckM ann (Type ann) +checkTypeIsEqual tyA tyB + = if void tyA == void tyB + then pure tyA + else throwError (TypeMismatch tyA tyB) + +checkTypesAreEqual :: NE.NonEmpty (Type ann) -> TypecheckM ann (Type ann) +checkTypesAreEqual tys = + foldrM checkTypeIsEqual (NE.head tys) (NE.tail tys) -- given the type of the expression in a pattern match, -- check that the pattern makes sense with it checkPattern :: - ( Show ann - ) => Type ann -> Pattern ann -> TypecheckM @@ -172,15 +181,15 @@ infer (ETuple ann fstExpr restExpr) = do (getOuterAnnotation typedFst) (getOuterAnnotation <$> typedRest) pure $ ETuple typ typedFst typedRest -infer (EPatternMatch ann matchExpr pats) = do +infer (EPatternMatch _ann matchExpr pats) = do elabExpr <- infer matchExpr let withPair (pat, patExpr) = do (elabPat, newVars) <- checkPattern (getOuterAnnotation elabExpr) pat - elabPatExpr <- withNewVars newVars (infer patExpr) + elabPatExpr <- withVars (M.toList newVars) (infer patExpr) pure (elabPat, elabPatExpr) elabPats <- traverse withPair pats let allTypes = getOuterAnnotation . snd <$> elabPats - typ <- combineMany allTypes + typ <- checkTypesAreEqual allTypes pure (EPatternMatch typ elabExpr elabPats) infer (EApply ann fnName args) = do fn <- lookupFunction ann fnName diff --git a/llvm-calc4/src/Calc/Typecheck/Types.hs b/llvm-calc4/src/Calc/Typecheck/Types.hs index 7b69bd4e..44cc2ce1 100644 --- a/llvm-calc4/src/Calc/Typecheck/Types.hs +++ b/llvm-calc4/src/Calc/Typecheck/Types.hs @@ -7,6 +7,7 @@ module Calc.Typecheck.Types TypecheckEnv (..), lookupVar, withVar, + withVars, lookupFunction, withFunctionArgs, storeFunction, @@ -97,14 +98,16 @@ withVar identifier ty = } ) -withFunctionArgs :: [(ArgumentName, Type ann)] -> TypecheckM ann a -> TypecheckM ann a -withFunctionArgs args = +withVars :: [(Identifier, Type ann)] -> TypecheckM ann a -> TypecheckM ann a +withVars args = local ( \tce -> tce - { tceVars = tceVars tce <> HM.fromList tidiedArgs + { tceVars = tceVars tce <> HM.fromList args } ) - where - tidiedArgs = - fmap (first (\(ArgumentName arg) -> Identifier arg)) args + +withFunctionArgs :: [(ArgumentName, Type ann)] -> + TypecheckM ann a -> TypecheckM ann a +withFunctionArgs = withVars . + fmap (first (\(ArgumentName arg) -> Identifier arg)) diff --git a/llvm-calc4/test/Test/Helpers.hs b/llvm-calc4/test/Test/Helpers.hs new file mode 100644 index 00000000..d76a5943 --- /dev/null +++ b/llvm-calc4/test/Test/Helpers.hs @@ -0,0 +1,47 @@ + +{-# LANGUAGE LambdaCase #-} + +module Test.Helpers (int, bool, var, tuple, patternMatch, tyInt, + tyBool, tyTuple, patTuple,patInt) where + +import Calc +import qualified Data.List.NonEmpty as NE +import Data.String + +int :: (Monoid ann) => Integer -> Expr ann +int = EPrim mempty . PInt + +bool :: (Monoid ann) => Bool -> Expr ann +bool = EPrim mempty . PBool + +var :: (Monoid ann) => String -> Expr ann +var = EVar mempty . Identifier . fromString + +tuple :: (Monoid ann) => [Expr ann] -> Expr ann +tuple = \case + (a : b : rest) -> ETuple mempty a (b NE.:| rest) + _ -> error "not enough items for tuple" + +patternMatch :: (Monoid ann) => Expr ann -> [(Pattern ann, Expr ann)] -> Expr ann +patternMatch matchExpr matches = + EPatternMatch mempty matchExpr (NE.fromList matches) + +tyInt :: (Monoid ann) => Type ann +tyInt = TPrim mempty TInt + +tyBool :: (Monoid ann) => Type ann +tyBool = TPrim mempty TBool + +tyTuple :: (Monoid ann) => [Type ann] -> Type ann +tyTuple = \case + (a : b : rest) -> TTuple mempty a (b NE.:| rest) + _ -> error "not enough items for tyTuple" + +patTuple :: (Monoid ann) => [Pattern ann] -> Pattern ann +patTuple = \case + (a : b : rest) -> PTuple mempty a (b NE.:| rest) + _ -> error "not enough items for patTuple" + +patInt :: (Monoid ann) => Integer -> Pattern ann +patInt = PLiteral mempty . PInt + diff --git a/llvm-calc4/test/Test/Parser/ParserSpec.hs b/llvm-calc4/test/Test/Parser/ParserSpec.hs index 0ac5e3c8..04a020cc 100644 --- a/llvm-calc4/test/Test/Parser/ParserSpec.hs +++ b/llvm-calc4/test/Test/Parser/ParserSpec.hs @@ -3,51 +3,13 @@ module Test.Parser.ParserSpec (spec) where +import Test.Helpers import Calc import Data.Foldable (traverse_) import Data.Functor -import qualified Data.List.NonEmpty as NE -import Data.String import qualified Data.Text as T import Test.Hspec -int :: (Monoid ann) => Integer -> Expr ann -int = EPrim mempty . PInt - -bool :: (Monoid ann) => Bool -> Expr ann -bool = EPrim mempty . PBool - -var :: (Monoid ann) => String -> Expr ann -var = EVar mempty . Identifier . fromString - -tuple :: (Monoid ann) => [Expr ann] -> Expr ann -tuple = \case - (a : b : rest) -> ETuple mempty a (b NE.:| rest) - _ -> error "not enough items for tuple" - -patternMatch :: (Monoid ann) => Expr ann -> [(Pattern ann, Expr ann)] -> Expr ann -patternMatch matchExpr matches = - EPatternMatch mempty matchExpr (NE.fromList matches) - -tyInt :: (Monoid ann) => Type ann -tyInt = TPrim mempty TInt - -tyBool :: (Monoid ann) => Type ann -tyBool = TPrim mempty TBool - -tyTuple :: (Monoid ann) => [Type ann] -> Type ann -tyTuple = \case - (a : b : rest) -> TTuple mempty a (b NE.:| rest) - _ -> error "not enough items for tyTuple" - -patTuple :: (Monoid ann) => [Pattern ann] -> Pattern ann -patTuple = \case - (a : b : rest) -> PTuple mempty a (b NE.:| rest) - _ -> error "not enough items for patTuple" - -patInt :: (Monoid ann) => Integer -> Pattern ann -patInt = PLiteral mempty . PInt - spec :: Spec spec = do describe "ParserSpec" $ do diff --git a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs index acfb9c2a..0be3aa1a 100644 --- a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs +++ b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs @@ -2,6 +2,7 @@ module Test.Typecheck.TypecheckSpec (spec) where +import Test.Helpers import Calc.ExprUtils import Calc.Parser import Calc.Typecheck.Elaborate @@ -68,9 +69,9 @@ spec = do describe "TypecheckSpec" $ do describe "Function" $ do let succeeding = - [ ("function one () { 1 }", TFunction () [] (TPrim () TInt)), + [ ("function one () { 1 }", TFunction () [] tyInt), ( "function not (bool: Boolean) { if bool then False else True }", - TFunction () [TPrim () TBool] (TPrim () TBool) + TFunction () [tyBool ] tyBool ) ] @@ -79,8 +80,8 @@ spec = do describe "Module" $ do let succeeding = - [ ("function ignore() { 1 } 42", TPrim () TInt), - ("function increment(a: Integer) { a + 1 } increment(41)", TPrim () TInt), + [ ("function ignore() { 1 } 42", tyInt ), + ("function increment(a: Integer) { a + 1 } increment(41)", tyInt ), ("function inc(a: Integer) { a + 1 } function inc2(a: Integer) { inc(a) } inc2(41)", TPrim () TInt) ] describe "Successfully typechecking modules" $ do @@ -103,19 +104,21 @@ spec = do ("if True then 1 else 2", "Integer"), ("if False then True else False", "Boolean"), ("(1,2,True)", "(Integer,Integer,Boolean)"), - ("case (1,2,3) of (a,b,_) -> a + b", "Integer") + ("case (1,2,3) of (a,b,_) -> a + b", "Integer"), + ("case (1,True) of (2,b) -> b | _ -> False", "Boolean") ] describe "Successfully typechecking expressions" $ do traverse_ testTypecheck succeeding let failing = - [ ("if 1 then 1 else 2", PredicateIsNotBoolean () (TPrim () TInt)), - ("if True then 1 else True", TypeMismatch (TPrim () TInt) (TPrim () TBool)), - ("1 + True", InfixTypeMismatch OpAdd [(TPrim () TInt, TPrim () TBool)]), - ("True + False", InfixTypeMismatch OpAdd [(TPrim () TInt, TPrim () TBool), (TPrim () TInt, TPrim () TBool)]), + [ ("if 1 then 1 else 2", PredicateIsNotBoolean () tyInt ), + ("if True then 1 else True", TypeMismatch tyInt tyBool), + ("1 + True", InfixTypeMismatch OpAdd [(tyInt, tyBool)]), + ("True + False", InfixTypeMismatch OpAdd [(tyInt, tyBool), (tyInt,tyBool)]), ("1 * False", InfixTypeMismatch OpMultiply [(TPrim () TInt, TPrim () TBool)]), - ("True - 1", InfixTypeMismatch OpSubtract [(TPrim () TInt, TPrim () TBool)]) + ("True - 1", InfixTypeMismatch OpSubtract [(TPrim () TInt, TPrim () TBool)]), + ("case (1,True) of (a, False) -> a | (_,c) -> c", TypeMismatch tyBool tyInt) ] describe "Failing typechecking expressions" $ do From 2239f43dc542d5d842d7ffa58c7c997667b033e9 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Mon, 24 Apr 2023 18:17:41 +0100 Subject: [PATCH 06/13] Error fixes --- llvm-calc4/src/Calc/Typecheck/Error.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llvm-calc4/src/Calc/Typecheck/Error.hs b/llvm-calc4/src/Calc/Typecheck/Error.hs index 1b0ab541..328610bb 100644 --- a/llvm-calc4/src/Calc/Typecheck/Error.hs +++ b/llvm-calc4/src/Calc/Typecheck/Error.hs @@ -31,6 +31,8 @@ data TypeError ann | FunctionArgumentLengthMismatch ann Int Int -- expected, actual | NonFunctionTypeFound ann (Type ann) | PatternMismatch (Pattern ann) (Type ann) + | FunctionArgumentLengthMismatch ann Int Int -- expected, actual + | NonFunctionTypeFound ann (Type ann) deriving stock (Eq, Ord, Show) positionFromAnnotation :: From bde0bb05a5b3fa684872d0d7d6226706e40d171d Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Mon, 1 May 2023 21:02:45 +0100 Subject: [PATCH 07/13] Fixed --- llvm-calc4/src/Calc/Typecheck/Error.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/llvm-calc4/src/Calc/Typecheck/Error.hs b/llvm-calc4/src/Calc/Typecheck/Error.hs index 328610bb..5c5887b1 100644 --- a/llvm-calc4/src/Calc/Typecheck/Error.hs +++ b/llvm-calc4/src/Calc/Typecheck/Error.hs @@ -28,8 +28,6 @@ data TypeError ann | TypeMismatch (Type ann) (Type ann) | VarNotFound ann Identifier (HashSet Identifier) | FunctionNotFound ann FunctionName (HashSet FunctionName) - | FunctionArgumentLengthMismatch ann Int Int -- expected, actual - | NonFunctionTypeFound ann (Type ann) | PatternMismatch (Pattern ann) (Type ann) | FunctionArgumentLengthMismatch ann Int Int -- expected, actual | NonFunctionTypeFound ann (Type ann) From 8810e163d0e756826dacd21b1f275411026ac62a Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Tue, 23 May 2023 21:53:44 +0100 Subject: [PATCH 08/13] Interpreting pattern matches --- llvm-calc4/src/Calc/Interpreter.hs | 50 +++++++++++++++++-- .../test/Test/Interpreter/InterpreterSpec.hs | 6 ++- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/llvm-calc4/src/Calc/Interpreter.hs b/llvm-calc4/src/Calc/Interpreter.hs index e9ec956f..9f99f9e6 100644 --- a/llvm-calc4/src/Calc/Interpreter.hs +++ b/llvm-calc4/src/Calc/Interpreter.hs @@ -13,6 +13,8 @@ module Calc.Interpreter ) where +import Data.Monoid (First(..)) +import qualified Data.List.NonEmpty as NE import Calc.Types import Control.Monad.Except import Control.Monad.Reader @@ -62,12 +64,12 @@ runInterpreter = flip evalStateT initialState . flip runReaderT initialEnv . run -- we use the Reader env here because the vars disappear after we use them, -- say, in a function withVars :: - [(ArgumentName, b)] -> + [ArgumentName] -> [Expr ann] -> InterpretM ann a -> InterpretM ann a withVars fnArgs inputs = - let newVars = M.fromList $ zip (coerce . fst <$> fnArgs) inputs + let newVars = M.fromList $ zip (coerce <$> fnArgs) inputs in local ( \(InterpreterEnv ieVars) -> InterpreterEnv $ ieVars <> newVars @@ -109,7 +111,7 @@ interpretApply fnName args = do fn <- gets (M.lookup fnName . isFunctions) case fn of Just (Function {fnArgs, fnBody}) -> - withVars fnArgs args (interpret fnBody) + withVars (fst <$> fnArgs) args (interpret fnBody) Nothing -> do allFnNames <- gets (M.keys . isFunctions) throwError (FunctionNotFound fnName allFnNames) @@ -127,7 +129,9 @@ interpret (EApply _ fnName args) = interpret (EInfix ann op a b) = interpretInfix ann op a b interpret (ETuple {}) = error "interpret ETuple" -interpret (EPatternMatch {}) = error "interpret EPatternMatch" +interpret (EPatternMatch _ expr pats) = do + exprA <- interpret expr + interpretPatternMatch exprA pats interpret (EIf ann predExpr thenExpr elseExpr) = do predA <- interpret predExpr case predA of @@ -135,6 +139,44 @@ interpret (EIf ann predExpr thenExpr elseExpr) = do (EPrim _ (PBool False)) -> interpret elseExpr other -> throwError (NonBooleanPredicate ann other) +interpretPatternMatch :: + Expr ann -> + NE.NonEmpty (Pattern ann, Expr ann) -> + InterpretM ann (Expr ann) +interpretPatternMatch expr' patterns = do + -- interpret match expression + intExpr <- interpret expr' + let foldF (pat, patExpr) = case patternMatches pat intExpr of + Just bindings -> First (Just (patExpr, bindings)) + _ -> First Nothing + + -- get first matching pattern + case getFirst (foldMap foldF patterns) of + Just (patExpr, bindings) -> + let vars = fmap (coerce . fst) bindings + exprs = fmap snd bindings + in withVars vars exprs (interpret patExpr) + _ -> + error "pattern match failure" + +-- pull vars out of expr to match patterns +patternMatches :: + Pattern ann -> + Expr ann -> + Maybe [(Identifier, Expr ann)] +patternMatches (PWildcard _) _ = pure [] +patternMatches (PVar _ name) expr = pure [(name, expr)] +patternMatches (PTuple _ pA pAs) (ETuple _ a as) = do + matchA <- patternMatches pA a + matchAs <- + traverse + (uncurry patternMatches) + (zip (NE.toList pAs) (NE.toList as)) + pure $ matchA <> mconcat matchAs +patternMatches (PLiteral _ pB) (EPrim _ b) + | pB == b = pure mempty +patternMatches _ _ = Nothing + interpretModule :: Module ann -> InterpretM ann (Expr ann) diff --git a/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs b/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs index fe64a5fd..0d996fa7 100644 --- a/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs +++ b/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs @@ -43,7 +43,8 @@ spec = do describe "Modules" $ do let cases = [ ("1 + 1", "2"), - ("function increment(a: Integer) { a + 1 } increment(-11)", "-10") + ("function increment(a: Integer) { a + 1 } increment(-11)", "-10"), + ("function swap(pair: (Integer,Boolean)) { case pair of (a,b) -> (b,a) } swap(1,True)", "(True, 1)") ] traverse_ ( \(input, expect) -> @@ -61,7 +62,8 @@ spec = do ("1 + 1 == 2", "True"), ("2 + 2 == 5", "False"), ("if False then True else False", "False"), - ("if 1 == 1 then 6 else 5", "6") + ("if 1 == 1 then 6 else 5", "6"), + ("case (1, True) of (a,False) -> a | (_,c) -> c", "True") ] traverse_ ( \(input, expect) -> From 6fe30d73d9f984378051521fb5a15e3cf5903fff Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Thu, 25 May 2023 13:32:09 +0100 Subject: [PATCH 09/13] OH --- llvm-calc4/src/Calc/Interpreter.hs | 19 +++--- llvm-calc4/src/Calc/Parser/Pattern.hs | 2 +- llvm-calc4/src/Calc/PatternUtils.hs | 2 +- llvm-calc4/src/Calc/Patterns/Flatten.hs | 67 +++++++++++++++++++ llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 12 ++-- llvm-calc4/src/Calc/Typecheck/Types.hs | 11 +-- llvm-calc4/src/Calc/Utils.hs | 5 +- llvm-calc4/test/Main.hs | 4 +- llvm-calc4/test/Test/Helpers.hs | 17 +++-- .../test/Test/Interpreter/InterpreterSpec.hs | 2 +- llvm-calc4/test/Test/Parser/ParserSpec.hs | 9 ++- llvm-calc4/test/Test/Patterns/PatternsSpec.hs | 36 ++++++++++ .../test/Test/Typecheck/TypecheckSpec.hs | 12 ++-- 13 files changed, 162 insertions(+), 36 deletions(-) create mode 100644 llvm-calc4/src/Calc/Patterns/Flatten.hs create mode 100644 llvm-calc4/test/Test/Patterns/PatternsSpec.hs diff --git a/llvm-calc4/src/Calc/Interpreter.hs b/llvm-calc4/src/Calc/Interpreter.hs index 9f99f9e6..f52bf95f 100644 --- a/llvm-calc4/src/Calc/Interpreter.hs +++ b/llvm-calc4/src/Calc/Interpreter.hs @@ -13,15 +13,15 @@ module Calc.Interpreter ) where -import Data.Monoid (First(..)) -import qualified Data.List.NonEmpty as NE import Calc.Types import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State import Data.Coerce +import qualified Data.List.NonEmpty as NE import Data.Map.Strict (Map) import qualified Data.Map.Strict as M +import Data.Monoid (First (..)) -- | type for interpreter state newtype InterpreterState ann = InterpreterState @@ -33,6 +33,7 @@ data InterpreterError ann = NonBooleanPredicate ann (Expr ann) | FunctionNotFound FunctionName [FunctionName] | VarNotFound Identifier [Identifier] + | NoPatternsMatched (Expr ann) (NE.NonEmpty (Pattern ann)) deriving stock (Eq, Ord, Show) -- | type of Reader env for interpreter state @@ -128,7 +129,10 @@ interpret (EApply _ fnName args) = interpretApply fnName args interpret (EInfix ann op a b) = interpretInfix ann op a b -interpret (ETuple {}) = error "interpret ETuple" +interpret (ETuple ann a as) = do + aA <- interpret a + asA <- traverse interpret as + pure (ETuple ann aA asA) interpret (EPatternMatch _ expr pats) = do exprA <- interpret expr interpretPatternMatch exprA pats @@ -153,11 +157,10 @@ interpretPatternMatch expr' patterns = do -- get first matching pattern case getFirst (foldMap foldF patterns) of Just (patExpr, bindings) -> - let vars = fmap (coerce . fst) bindings - exprs = fmap snd bindings - in withVars vars exprs (interpret patExpr) - _ -> - error "pattern match failure" + let vars = fmap (coerce . fst) bindings + exprs = fmap snd bindings + in withVars vars exprs (interpret patExpr) + _ -> throwError (NoPatternsMatched expr' (fst <$> patterns)) -- pull vars out of expr to match patterns patternMatches :: diff --git a/llvm-calc4/src/Calc/Parser/Pattern.hs b/llvm-calc4/src/Calc/Parser/Pattern.hs index 58b441e7..a959ce8f 100644 --- a/llvm-calc4/src/Calc/Parser/Pattern.hs +++ b/llvm-calc4/src/Calc/Parser/Pattern.hs @@ -60,4 +60,4 @@ patTupleParser = label "tuple" $ ---- patLitParser :: Parser ParserPattern -patLitParser = withLocation PLiteral primParser +patLitParser = myLexeme $ withLocation PLiteral primParser diff --git a/llvm-calc4/src/Calc/PatternUtils.hs b/llvm-calc4/src/Calc/PatternUtils.hs index c6b64fe1..2433d884 100644 --- a/llvm-calc4/src/Calc/PatternUtils.hs +++ b/llvm-calc4/src/Calc/PatternUtils.hs @@ -5,5 +5,5 @@ import Calc.Types.Pattern getPatternAnnotation :: Pattern ann -> ann getPatternAnnotation (PLiteral ann _) = ann getPatternAnnotation (PWildcard ann) = ann -getPatternAnnotation (PVar ann _ ) = ann +getPatternAnnotation (PVar ann _) = ann getPatternAnnotation (PTuple ann _ _) = ann diff --git a/llvm-calc4/src/Calc/Patterns/Flatten.hs b/llvm-calc4/src/Calc/Patterns/Flatten.hs new file mode 100644 index 00000000..a3e46a2b --- /dev/null +++ b/llvm-calc4/src/Calc/Patterns/Flatten.hs @@ -0,0 +1,67 @@ +{-# LANGUAGE DerivingStrategies #-} + +module Calc.Patterns.Flatten (SimpleExpr(..), + SimplePattern(..), flattenPatterns) where + +import Data.Bifunctor (first) +import Control.Monad (void) +import Calc.Types +import Data.List.NonEmpty as NE + +-- we wish to flatten patterns like this +-- +-- case p of (1,2) -> True | (_,_) -> False +-- +-- becomes +-- +-- case p of +-- [1,b] -> case b of +-- 2 -> True +-- _ -> False +-- _ -> False + +-- case p of (1,2,3) -> True | (_,_,_) -> False +-- +-- becomes +-- +-- case p of +-- [1, b, c] -> case b of +-- 2 -> case c of +-- 3 -> True +-- _ -> False +-- _ -> False +-- _ -> False + +-- we'll identify missing patterns when we're unable to fill in the fallthrough +-- cases + +-- is this it? +data SimplePattern + = SPTuple Int + | SPPrim Prim + | SPWildcard + deriving stock (Eq,Ord,Show) + +data SimpleExpr + = SEPrim Prim + | SETupleItem Int SimpleExpr + | SEPatternMatch SimpleExpr [(SimplePattern, SimpleExpr)] + deriving stock (Eq,Ord,Show) + +-- | this does not bind variables yet +flattenPatterns :: NE.NonEmpty (Pattern ann, Expr ann) -> [(SimplePattern, SimpleExpr )] +flattenPatterns = + fmap (uncurry flattenPattern . first unrollTuples) . NE.toList + +unrollTuples :: Pattern ann -> [Pattern ann] +unrollTuples (PTuple _ a as) = [a] <> NE.toList as +unrollTuples other = [other] + +flattenPattern :: [Pattern ann] -> Expr ann -> (SimplePattern, SimpleExpr) +flattenPattern [PWildcard _] expr = (SPWildcard, simpleExpr expr) +flattenPattern [PLiteral _ prim] expr = (SPPrim prim, simpleExpr expr) +flattenPattern other _ = error $ "flattenPattern " <> show (fmap void other) + +simpleExpr :: Expr ann -> SimpleExpr +simpleExpr (EPrim _ prim) = SEPrim prim +simpleExpr other = error $ "simpleExpr " <> show (void other) diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index 98210440..c34be1ae 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -4,7 +4,6 @@ module Calc.Typecheck.Elaborate (elaborate, elaborateFunction, elaborateModule) where -import Data.Foldable (foldrM) import Calc.ExprUtils import Calc.PatternUtils import Calc.TypeUtils @@ -17,10 +16,11 @@ import Calc.Types.Module import Calc.Types.Pattern import Calc.Types.Prim import Calc.Types.Type -import Control.Monad (when, zipWithM) import Calc.Utils +import Control.Monad (when, zipWithM) import Control.Monad.Except import Data.Bifunctor (second) +import Data.Foldable (foldrM) import Data.Functor import qualified Data.List.NonEmpty as NE import Data.Map.Strict (Map) @@ -62,10 +62,10 @@ check ty expr = do -- simple check for now checkTypeIsEqual :: Type ann -> Type ann -> TypecheckM ann (Type ann) -checkTypeIsEqual tyA tyB - = if void tyA == void tyB - then pure tyA - else throwError (TypeMismatch tyA tyB) +checkTypeIsEqual tyA tyB = + if void tyA == void tyB + then pure tyA + else throwError (TypeMismatch tyA tyB) checkTypesAreEqual :: NE.NonEmpty (Type ann) -> TypecheckM ann (Type ann) checkTypesAreEqual tys = diff --git a/llvm-calc4/src/Calc/Typecheck/Types.hs b/llvm-calc4/src/Calc/Typecheck/Types.hs index 44cc2ce1..c28cdcf2 100644 --- a/llvm-calc4/src/Calc/Typecheck/Types.hs +++ b/llvm-calc4/src/Calc/Typecheck/Types.hs @@ -107,7 +107,10 @@ withVars args = } ) -withFunctionArgs :: [(ArgumentName, Type ann)] -> - TypecheckM ann a -> TypecheckM ann a -withFunctionArgs = withVars . - fmap (first (\(ArgumentName arg) -> Identifier arg)) +withFunctionArgs :: + [(ArgumentName, Type ann)] -> + TypecheckM ann a -> + TypecheckM ann a +withFunctionArgs = + withVars + . fmap (first (\(ArgumentName arg) -> Identifier arg)) diff --git a/llvm-calc4/src/Calc/Utils.hs b/llvm-calc4/src/Calc/Utils.hs index ccb646ae..b90efab9 100644 --- a/llvm-calc4/src/Calc/Utils.hs +++ b/llvm-calc4/src/Calc/Utils.hs @@ -2,9 +2,9 @@ module Calc.Utils (neZipWithM, neUnzip) where -- useful junk goes here -import qualified Data.List.NonEmpty as NE -import Data.Bifunctor import Control.Monad (zipWithM) +import Data.Bifunctor +import qualified Data.List.NonEmpty as NE neZipWithM :: (Applicative m) => @@ -17,4 +17,3 @@ neZipWithM f as bs = neUnzip :: NE.NonEmpty (a, b) -> (NE.NonEmpty a, NE.NonEmpty b) neUnzip = bimap NE.fromList NE.fromList . unzip . NE.toList - diff --git a/llvm-calc4/test/Main.hs b/llvm-calc4/test/Main.hs index 9ba43c30..0755441a 100644 --- a/llvm-calc4/test/Main.hs +++ b/llvm-calc4/test/Main.hs @@ -4,11 +4,13 @@ import Test.Hspec import qualified Test.Interpreter.InterpreterSpec import qualified Test.LLVM.LLVMSpec import qualified Test.Parser.ParserSpec +import qualified Test.Patterns.PatternsSpec import qualified Test.Typecheck.TypecheckSpec main :: IO () main = hspec $ do Test.Parser.ParserSpec.spec Test.Interpreter.InterpreterSpec.spec - Test.LLVM.LLVMSpec.spec Test.Typecheck.TypecheckSpec.spec + Test.Patterns.PatternsSpec.spec + Test.LLVM.LLVMSpec.spec diff --git a/llvm-calc4/test/Test/Helpers.hs b/llvm-calc4/test/Test/Helpers.hs index d76a5943..a0517d67 100644 --- a/llvm-calc4/test/Test/Helpers.hs +++ b/llvm-calc4/test/Test/Helpers.hs @@ -1,8 +1,18 @@ - {-# LANGUAGE LambdaCase #-} -module Test.Helpers (int, bool, var, tuple, patternMatch, tyInt, - tyBool, tyTuple, patTuple,patInt) where +module Test.Helpers + ( int, + bool, + var, + tuple, + patternMatch, + tyInt, + tyBool, + tyTuple, + patTuple, + patInt, + ) +where import Calc import qualified Data.List.NonEmpty as NE @@ -44,4 +54,3 @@ patTuple = \case patInt :: (Monoid ann) => Integer -> Pattern ann patInt = PLiteral mempty . PInt - diff --git a/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs b/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs index 0d996fa7..80b71743 100644 --- a/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs +++ b/llvm-calc4/test/Test/Interpreter/InterpreterSpec.hs @@ -44,7 +44,7 @@ spec = do let cases = [ ("1 + 1", "2"), ("function increment(a: Integer) { a + 1 } increment(-11)", "-10"), - ("function swap(pair: (Integer,Boolean)) { case pair of (a,b) -> (b,a) } swap(1,True)", "(True, 1)") + ("function swap(pair: (Integer,Boolean)) { case pair of (a,b) -> (b,a) } swap((1,True))", "(True, 1)") ] traverse_ ( \(input, expect) -> diff --git a/llvm-calc4/test/Test/Parser/ParserSpec.hs b/llvm-calc4/test/Test/Parser/ParserSpec.hs index 04a020cc..85cc5487 100644 --- a/llvm-calc4/test/Test/Parser/ParserSpec.hs +++ b/llvm-calc4/test/Test/Parser/ParserSpec.hs @@ -3,11 +3,11 @@ module Test.Parser.ParserSpec (spec) where -import Test.Helpers import Calc import Data.Foldable (traverse_) import Data.Functor import qualified Data.Text as T +import Test.Helpers import Test.Hspec spec :: Spec @@ -100,6 +100,13 @@ spec = do [ (patTuple [patInt 5, patInt 6, patInt 7], bool True), (patTuple [patInt 1, patInt 2, patInt 3], bool False) ] + ), + ( "case a of 100 -> True | _ -> False", + patternMatch + (var "a") + [ (patInt 100, bool True), + (PWildcard (), bool False) + ] ) ] traverse_ diff --git a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs new file mode 100644 index 00000000..dc1b2d75 --- /dev/null +++ b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Test.Patterns.PatternsSpec (spec) where + +import Calc +import Control.Monad (void) +import Data.Functor (($>)) +import Data.Text (Text) +import Test.Hspec +import qualified Data.List.NonEmpty as NE +import Data.Bifunctor (bimap) +import Calc.Patterns.Flatten (flattenPatterns, SimpleExpr(..),SimplePattern(..)) + +-- | try parsing the input, exploding if it's invalid +unsafeParsePattern :: Text -> (Expr (), NE.NonEmpty (Pattern (), Expr ())) +unsafeParsePattern input = case parseExprAndFormatError input of + Right (EPatternMatch _ expr pats) -> + (expr $> (), fmap (bimap void void) pats ) + Right other -> error $ "expected pattern match, got " <> show other + Left e -> error (show e) + +spec :: Spec +spec = do + describe "PatternsSpec" $ do + it "Trivial wildcard pattern converts without trouble" $ do + let (_expr,pats) = unsafeParsePattern "case a of _ -> 2" + flattenPatterns pats `shouldBe` [(SPWildcard, SEPrim (PInt 2))] + it "Trivial primitive pattern converts without trouble" $ do + let (_expr,pats) = unsafeParsePattern "case a of 1 -> 2 | _ -> 0" + flattenPatterns pats `shouldBe` [(SPPrim (PInt 1), SEPrim (PInt 2)), + (SPWildcard, SEPrim (PInt 0))] + it "Tuple is split into two patterns" $ do + let (_expr,pats) = unsafeParsePattern "case p of (1,2) -> True | (_,_) -> False" + flattenPatterns pats `shouldBe` [(SPPrim (PInt 1), SEPrim (PInt 2)), + (SPWildcard, SEPrim (PInt 0))] + diff --git a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs index 0be3aa1a..de13366a 100644 --- a/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs +++ b/llvm-calc4/test/Test/Typecheck/TypecheckSpec.hs @@ -2,7 +2,6 @@ module Test.Typecheck.TypecheckSpec (spec) where -import Test.Helpers import Calc.ExprUtils import Calc.Parser import Calc.Typecheck.Elaborate @@ -16,6 +15,7 @@ import Control.Monad import Data.Either (isLeft) import Data.Foldable (traverse_) import Data.Text (Text) +import Test.Helpers import Test.Hspec runTC :: TypecheckM ann a -> Either (TypeError ann) a @@ -71,7 +71,7 @@ spec = do let succeeding = [ ("function one () { 1 }", TFunction () [] tyInt), ( "function not (bool: Boolean) { if bool then False else True }", - TFunction () [tyBool ] tyBool + TFunction () [tyBool] tyBool ) ] @@ -80,8 +80,8 @@ spec = do describe "Module" $ do let succeeding = - [ ("function ignore() { 1 } 42", tyInt ), - ("function increment(a: Integer) { a + 1 } increment(41)", tyInt ), + [ ("function ignore() { 1 } 42", tyInt), + ("function increment(a: Integer) { a + 1 } increment(41)", tyInt), ("function inc(a: Integer) { a + 1 } function inc2(a: Integer) { inc(a) } inc2(41)", TPrim () TInt) ] describe "Successfully typechecking modules" $ do @@ -112,10 +112,10 @@ spec = do traverse_ testTypecheck succeeding let failing = - [ ("if 1 then 1 else 2", PredicateIsNotBoolean () tyInt ), + [ ("if 1 then 1 else 2", PredicateIsNotBoolean () tyInt), ("if True then 1 else True", TypeMismatch tyInt tyBool), ("1 + True", InfixTypeMismatch OpAdd [(tyInt, tyBool)]), - ("True + False", InfixTypeMismatch OpAdd [(tyInt, tyBool), (tyInt,tyBool)]), + ("True + False", InfixTypeMismatch OpAdd [(tyInt, tyBool), (tyInt, tyBool)]), ("1 * False", InfixTypeMismatch OpMultiply [(TPrim () TInt, TPrim () TBool)]), ("True - 1", InfixTypeMismatch OpSubtract [(TPrim () TInt, TPrim () TBool)]), ("case (1,True) of (a, False) -> a | (_,c) -> c", TypeMismatch tyBool tyInt) From c6b962a884caf2448689c01f77cc652f24fcd6db Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Thu, 25 May 2023 16:50:45 +0100 Subject: [PATCH 10/13] Quite good --- llvm-calc4/a.out | Bin 0 -> 51200 bytes llvm-calc4/src/Calc/Patterns/Flatten.hs | 133 ++++++++++-------- llvm-calc4/test/Test/Helpers.hs | 4 + llvm-calc4/test/Test/Patterns/PatternsSpec.hs | 57 ++++++-- 4 files changed, 127 insertions(+), 67 deletions(-) create mode 100755 llvm-calc4/a.out diff --git a/llvm-calc4/a.out b/llvm-calc4/a.out new file mode 100755 index 0000000000000000000000000000000000000000..b31b2ecaeaba0cc7253cc336d693c087080102ce GIT binary patch literal 51200 zcmeI*ZERCj7zgmvc5j3+H+*3#x(ZCd&~+8IK|maAgMqlgSj>D$&f4{E?daS3vaO7b zCJ4c3;xGn?35c2?izKpgo@e3rz|C5~F z^W1apbI<+tu6#Or?&cr2a)mGn3Ny9o)Dl@jY!_DQ3Gq0!T53|RU$MI8wVH-{KARfq z#i^(6$hgitx`C99HBF7F!|L^wspo{AXVPhIv?Mj8j4ItxJ)uK-y}{+$!x&RI=jWP; z43U`VQ%A~3H0Fv-oJy}(`l42kjMU&fxn7Q5&luP0Nsrh%`u7jbS$E|Ave_hc_MDXUjBu8>vrbsM#}rBZ)M>8M89k!y0} zcPXEmlx;!d9iVj5dO1AEQJK0fe&3ScYhKrH)Ug1jl>Fdi$XzMfTaegoVS}G;wx;1q#)l{$1 z^o1Bsdrr#L&53bCLZsnoab9dglMMY zRBG#~aTxb7eh-ozv>CTQ^XdA#JtzO2SgG;ft34J8+kIY_-5vM2Dd%{enw1*gOS}72 zwXfB3nxCVy{C*iRkDim)eVTq}l7s6nrp9BuPSdhnTDIx-J?Xw1CTfVT| zW20i;mSCXMR#vJ$o{=ZtXC|$?TfbjY?YOVIrOeq;R^Hy>b}Z@iyE@uJZLWBGpxxiK zq$BROIYS|zVpIR%yMp0OwklfDs$eWYZ=p9BusIwhr6r~IbzWE48IIdm@ec#foA`l! zmhQ=ZJBPLJdA^?W_s;PYH6BmU_s9PZAs_$&2tWV=5P$##AOHafKmY;|fB*y_009U< z00Izz00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00I#Bj|pUm zf?Y)KX zewZu`&fA&>E&DFYH(kT>b(5@ zrhWgkeb$(XBLpA-0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Izz z00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Iv{pi%$-vp;}a z;T8~p00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Izz00bZa z0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009V0B2X|}Y~Z1SN_X7nmXT=8 zC0*WtTM3H|VbL!lG$KP`Zy?%A8HGuTOpzh|9;tLIEwQK~J&Kd3+q{7m;qW`X0a31A zAeNeGzC{f7(#PpuAU~f+0JUZ4>%bOxZPJ(VleJ=5k zDW5*I{4q7B`OCHZNj*QU=g;VQhW_6*SfJ+@_3?7bsV?WVEOIIT7g1@^O{*B}07&&8PQ=Z$Y{?SJp&wRg`o_s%^( zRClKM)$)VK@Av&@&%SWNGkPU7&0FkPx_jo;k7TE1Wo-BH9jBBNXXbwYF}*$fwyST) U%}Z0cpEKEaiSrPE!2e0$Z;T;QWB>pF literal 0 HcmV?d00001 diff --git a/llvm-calc4/src/Calc/Patterns/Flatten.hs b/llvm-calc4/src/Calc/Patterns/Flatten.hs index a3e46a2b..2f1b174e 100644 --- a/llvm-calc4/src/Calc/Patterns/Flatten.hs +++ b/llvm-calc4/src/Calc/Patterns/Flatten.hs @@ -1,67 +1,90 @@ {-# LANGUAGE DerivingStrategies #-} + {-# LANGUAGE ScopedTypeVariables #-} +module Calc.Patterns.Flatten ( generateMissing) where -module Calc.Patterns.Flatten (SimpleExpr(..), - SimplePattern(..), flattenPatterns) where - -import Data.Bifunctor (first) -import Control.Monad (void) +import Debug.Trace +import Data.Functor (($>)) +import qualified Data.List import Calc.Types -import Data.List.NonEmpty as NE +import qualified Data.List.NonEmpty as NE + +-- | given our patterns, generate everything we need minus the ones we have +generateMissing :: (Eq ann, Show ann) => NE.NonEmpty (Pattern ann) -> [Pattern ann] +generateMissing nePats + = let pats = NE.toList nePats + in filterMissing pats (generatePatterns pats) + +-- | given our patterns, generate any others we might need +generatePatterns :: (Show ann) => [Pattern ann] -> [Pattern ann] +generatePatterns = concatMap generatePattern + +-- | given a pattern, generate all other patterns we'll need +generatePattern :: forall ann. (Show ann) => Pattern ann -> [Pattern ann] +generatePattern (PWildcard _) = mempty +generatePattern (PLiteral ann (PBool True)) = [PLiteral ann (PBool False)] +generatePattern (PLiteral ann (PBool False)) = [PLiteral ann (PBool True)] +generatePattern (PTuple ann a as) = + let genOrOriginal :: Pattern ann -> [Pattern ann] + genOrOriginal pat = + traceShowId $ case generatePattern (traceShowId pat) of + [] -> -- here we want to generate "everything" for the type to stop unnecessary wildcards + [pat] + pats -> if isTotal pat then pats else [pat] <> pats + + genAs :: [[Pattern ann]] + genAs = fmap genOrOriginal ([a] <> NE.toList as) --- we wish to flatten patterns like this --- --- case p of (1,2) -> True | (_,_) -> False --- --- becomes --- --- case p of --- [1,b] -> case b of --- 2 -> True --- _ -> False --- _ -> False + createTuple :: [Pattern ann] -> Pattern ann + createTuple items = + let ne = NE.fromList items + in PTuple ann (NE.head ne) (NE.fromList $ NE.tail ne) --- case p of (1,2,3) -> True | (_,_,_) -> False --- --- becomes --- --- case p of --- [1, b, c] -> case b of --- 2 -> case c of --- 3 -> True --- _ -> False --- _ -> False --- _ -> False + in fmap createTuple (sequence genAs) +generatePattern _ = mempty --- we'll identify missing patterns when we're unable to fill in the fallthrough --- cases +-- | wildcards are total, vars are total, products are total +isTotal :: Pattern ann -> Bool +isTotal (PWildcard _) = True +isTotal (PVar _ _) = True +isTotal (PTuple {}) = True +isTotal _ = False --- is this it? -data SimplePattern - = SPTuple Int - | SPPrim Prim - | SPWildcard - deriving stock (Eq,Ord,Show) +-- filter outstanding items +filterMissing :: + (Eq ann) => + [Pattern ann] -> + [Pattern ann] -> + [Pattern ann] +filterMissing patterns required = + Data.List.nub $ foldr annihiliatePattern required patterns + where + annihiliatePattern pat = + filter + ( not + . annihilate + (removeAnn pat) + . removeAnn + ) -data SimpleExpr - = SEPrim Prim - | SETupleItem Int SimpleExpr - | SEPatternMatch SimpleExpr [(SimplePattern, SimpleExpr)] - deriving stock (Eq,Ord,Show) +removeAnn :: Pattern ann -> Pattern () +removeAnn p = p $> () --- | this does not bind variables yet -flattenPatterns :: NE.NonEmpty (Pattern ann, Expr ann) -> [(SimplePattern, SimpleExpr )] -flattenPatterns = - fmap (uncurry flattenPattern . first unrollTuples) . NE.toList + {- +-- does left pattern satisfy right pattern? +annihilateAll :: + [(Pattern (), Pattern ())] -> + Bool +annihilateAll = + foldr + (\(a, b) keep -> keep && annihilate a b) + True +-} -unrollTuples :: Pattern ann -> [Pattern ann] -unrollTuples (PTuple _ a as) = [a] <> NE.toList as -unrollTuples other = [other] +-- | if left is on the right, should we get rid? +annihilate :: Pattern () -> Pattern () -> Bool +annihilate a b | a == b = True +annihilate (PWildcard _) _ = True -- wildcard trumps all +annihilate (PVar _ _) _ = True -- as does var +annihilate _ _as = False -flattenPattern :: [Pattern ann] -> Expr ann -> (SimplePattern, SimpleExpr) -flattenPattern [PWildcard _] expr = (SPWildcard, simpleExpr expr) -flattenPattern [PLiteral _ prim] expr = (SPPrim prim, simpleExpr expr) -flattenPattern other _ = error $ "flattenPattern " <> show (fmap void other) -simpleExpr :: Expr ann -> SimpleExpr -simpleExpr (EPrim _ prim) = SEPrim prim -simpleExpr other = error $ "simpleExpr " <> show (void other) diff --git a/llvm-calc4/test/Test/Helpers.hs b/llvm-calc4/test/Test/Helpers.hs index a0517d67..8a48f2a0 100644 --- a/llvm-calc4/test/Test/Helpers.hs +++ b/llvm-calc4/test/Test/Helpers.hs @@ -11,6 +11,7 @@ module Test.Helpers tyTuple, patTuple, patInt, + patBool ) where @@ -54,3 +55,6 @@ patTuple = \case patInt :: (Monoid ann) => Integer -> Pattern ann patInt = PLiteral mempty . PInt + +patBool :: (Monoid ann) => Bool -> Pattern ann +patBool = PLiteral mempty . PBool diff --git a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs index dc1b2d75..ba895dd7 100644 --- a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs +++ b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs @@ -9,7 +9,8 @@ import Data.Text (Text) import Test.Hspec import qualified Data.List.NonEmpty as NE import Data.Bifunctor (bimap) -import Calc.Patterns.Flatten (flattenPatterns, SimpleExpr(..),SimplePattern(..)) +import Calc.Patterns.Flatten (generateMissing) +import Test.Helpers -- | try parsing the input, exploding if it's invalid unsafeParsePattern :: Text -> (Expr (), NE.NonEmpty (Pattern (), Expr ())) @@ -21,16 +22,48 @@ unsafeParsePattern input = case parseExprAndFormatError input of spec :: Spec spec = do - describe "PatternsSpec" $ do - it "Trivial wildcard pattern converts without trouble" $ do + fdescribe "PatternsSpec" $ do + it "Wildcard is exhaustive" $ do let (_expr,pats) = unsafeParsePattern "case a of _ -> 2" - flattenPatterns pats `shouldBe` [(SPWildcard, SEPrim (PInt 2))] - it "Trivial primitive pattern converts without trouble" $ do - let (_expr,pats) = unsafeParsePattern "case a of 1 -> 2 | _ -> 0" - flattenPatterns pats `shouldBe` [(SPPrim (PInt 1), SEPrim (PInt 2)), - (SPWildcard, SEPrim (PInt 0))] - it "Tuple is split into two patterns" $ do - let (_expr,pats) = unsafeParsePattern "case p of (1,2) -> True | (_,_) -> False" - flattenPatterns pats `shouldBe` [(SPPrim (PInt 1), SEPrim (PInt 2)), - (SPWildcard, SEPrim (PInt 0))] + generateMissing (fst <$> pats) `shouldBe` [] + it "True needs false" $ do + let (_expr,pats) = unsafeParsePattern "case a of True -> 2" + generateMissing (fst <$> pats) `shouldBe` [patBool False ] + it "False needs True" $ do + let (_expr,pats) = unsafeParsePattern "case a of False -> 2" + generateMissing (fst <$> pats) `shouldBe` [patBool True] + it "False and True needs nothing" $ do + let (_expr,pats) = unsafeParsePattern "case a of False -> 2 | True -> 1" + generateMissing (fst <$> pats) `shouldBe` [] + it "Tuple of two wildcards needs nothing " $ do + let (_expr,pats) = unsafeParsePattern "case a of (_,_) -> True" + generateMissing (fst <$> pats) `shouldBe` [] + it "Tuple of one wildcard and one true needs a false" $ do + let (_expr,pats) = unsafeParsePattern "case a of (_,True) -> True" + generateMissing (fst <$> pats) `shouldBe` [patTuple [PWildcard (), patBool False]] + it "Tuple of one true and one false needs a bunch" $ do + let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True" + generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], + patTuple [patBool True, patBool True], + patTuple [patBool True, patBool False]] + it "Tuple of booleans with some things supplied" $ do + let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True | (True, False) -> False" + generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], + patTuple [patBool True, patBool True]] + + it "Tuple of booleans with some things supplied" $ do + let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True | (True, False) -> False" + generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], + patTuple [patBool True, patBool True]] + + it "Tuple of wildcard and boolean" $ do + let (_expr,pats) = unsafeParsePattern "case a of (False,_) -> True | (True, False) -> False" + generateMissing (fst <$> pats) `shouldBe` [ + patTuple [patBool True, patBool True]] + + + + + + From 6ae0ccea7cf6b9c55620403c019176512646eb15 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Thu, 25 May 2023 17:20:12 +0100 Subject: [PATCH 11/13] Exhaustiveness working ish --- llvm-calc4/src/Calc/Patterns/Flatten.hs | 80 +++++++------- llvm-calc4/test/Test/Helpers.hs | 2 +- llvm-calc4/test/Test/Patterns/PatternsSpec.hs | 102 ++++++++++-------- 3 files changed, 100 insertions(+), 84 deletions(-) diff --git a/llvm-calc4/src/Calc/Patterns/Flatten.hs b/llvm-calc4/src/Calc/Patterns/Flatten.hs index 2f1b174e..715c0e5e 100644 --- a/llvm-calc4/src/Calc/Patterns/Flatten.hs +++ b/llvm-calc4/src/Calc/Patterns/Flatten.hs @@ -1,48 +1,60 @@ {-# LANGUAGE DerivingStrategies #-} - {-# LANGUAGE ScopedTypeVariables #-} -module Calc.Patterns.Flatten ( generateMissing) where +{-# LANGUAGE ScopedTypeVariables #-} -import Debug.Trace +module Calc.Patterns.Flatten (generateMissing) where + +import Calc.PatternUtils (getPatternAnnotation) +import Calc.Types import Data.Functor (($>)) import qualified Data.List -import Calc.Types import qualified Data.List.NonEmpty as NE -- | given our patterns, generate everything we need minus the ones we have -generateMissing :: (Eq ann, Show ann) => NE.NonEmpty (Pattern ann) -> [Pattern ann] -generateMissing nePats - = let pats = NE.toList nePats - in filterMissing pats (generatePatterns pats) +generateMissing :: (Show ann) => NE.NonEmpty (Pattern (Type ann)) -> [Pattern ()] +generateMissing nePats = + let pats = NE.toList nePats + in filterMissing (removeAnn <$> pats) (generatePatterns pats) -- | given our patterns, generate any others we might need -generatePatterns :: (Show ann) => [Pattern ann] -> [Pattern ann] +generatePatterns :: (Show ann) => [Pattern (Type ann)] -> [Pattern ()] generatePatterns = concatMap generatePattern +generateForType :: Type ann -> [Pattern ()] +generateForType (TPrim _ TBool) = [PLiteral () (PBool True), PLiteral () (PBool False)] +generateForType (TPrim _ TInt) = [PWildcard ()] -- too many, just do wildcard +generateForType _ = [PWildcard ()] + +typeIsTotal :: Type ann -> Bool +typeIsTotal (TPrim {}) = False +typeIsTotal (TTuple {}) = True +typeIsTotal (TFunction {}) = True + -- | given a pattern, generate all other patterns we'll need -generatePattern :: forall ann. (Show ann) => Pattern ann -> [Pattern ann] +generatePattern :: forall ann. (Show ann) => Pattern (Type ann) -> [Pattern ()] generatePattern (PWildcard _) = mempty -generatePattern (PLiteral ann (PBool True)) = [PLiteral ann (PBool False)] -generatePattern (PLiteral ann (PBool False)) = [PLiteral ann (PBool True)] -generatePattern (PTuple ann a as) = - let genOrOriginal :: Pattern ann -> [Pattern ann] +generatePattern (PLiteral _ (PBool True)) = [PLiteral () (PBool False)] +generatePattern (PLiteral _ (PBool False)) = [PLiteral () (PBool True)] +generatePattern (PTuple _ a as) = + let genOrOriginal :: Pattern (Type ann) -> [Pattern ()] genOrOriginal pat = - traceShowId $ case generatePattern (traceShowId pat) of - [] -> -- here we want to generate "everything" for the type to stop unnecessary wildcards - [pat] - pats -> if isTotal pat then pats else [pat] <> pats + case generatePattern pat of + [] -> + if typeIsTotal (getPatternAnnotation pat) + then [removeAnn pat] + else generateForType (getPatternAnnotation pat) + pats -> if isTotal pat then pats else [removeAnn pat] <> pats - genAs :: [[Pattern ann]] + genAs :: [[Pattern ()]] genAs = fmap genOrOriginal ([a] <> NE.toList as) - createTuple :: [Pattern ann] -> Pattern ann + createTuple :: [Pattern ()] -> Pattern () createTuple items = - let ne = NE.fromList items - in PTuple ann (NE.head ne) (NE.fromList $ NE.tail ne) - - in fmap createTuple (sequence genAs) + let ne = NE.fromList items + in PTuple () (NE.head ne) (NE.fromList $ NE.tail ne) + in fmap createTuple (sequence genAs) generatePattern _ = mempty --- | wildcards are total, vars are total, products are total +-- | wildcards are total, vars are total, products are total isTotal :: Pattern ann -> Bool isTotal (PWildcard _) = True isTotal (PVar _ _) = True @@ -51,25 +63,21 @@ isTotal _ = False -- filter outstanding items filterMissing :: - (Eq ann) => - [Pattern ann] -> - [Pattern ann] -> - [Pattern ann] + [Pattern ()] -> + [Pattern ()] -> + [Pattern ()] filterMissing patterns required = Data.List.nub $ foldr annihiliatePattern required patterns where annihiliatePattern pat = filter ( not - . annihilate - (removeAnn pat) - . removeAnn + . annihilate pat ) removeAnn :: Pattern ann -> Pattern () removeAnn p = p $> () - {- -- does left pattern satisfy right pattern? annihilateAll :: [(Pattern (), Pattern ())] -> @@ -78,13 +86,13 @@ annihilateAll = foldr (\(a, b) keep -> keep && annihilate a b) True --} -- | if left is on the right, should we get rid? annihilate :: Pattern () -> Pattern () -> Bool annihilate a b | a == b = True annihilate (PWildcard _) _ = True -- wildcard trumps all annihilate (PVar _ _) _ = True -- as does var +annihilate (PTuple _ a as) (PTuple _ b bs) = + let allPairs = zip ([a] <> NE.toList as) ([b] <> NE.toList bs) + in annihilateAll allPairs annihilate _ _as = False - - diff --git a/llvm-calc4/test/Test/Helpers.hs b/llvm-calc4/test/Test/Helpers.hs index 8a48f2a0..af4adff2 100644 --- a/llvm-calc4/test/Test/Helpers.hs +++ b/llvm-calc4/test/Test/Helpers.hs @@ -11,7 +11,7 @@ module Test.Helpers tyTuple, patTuple, patInt, - patBool + patBool, ) where diff --git a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs index ba895dd7..b84dea73 100644 --- a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs +++ b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs @@ -1,69 +1,77 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} module Test.Patterns.PatternsSpec (spec) where import Calc -import Control.Monad (void) -import Data.Functor (($>)) -import Data.Text (Text) -import Test.Hspec -import qualified Data.List.NonEmpty as NE -import Data.Bifunctor (bimap) import Calc.Patterns.Flatten (generateMissing) +import qualified Data.List.NonEmpty as NE import Test.Helpers - --- | try parsing the input, exploding if it's invalid -unsafeParsePattern :: Text -> (Expr (), NE.NonEmpty (Pattern (), Expr ())) -unsafeParsePattern input = case parseExprAndFormatError input of - Right (EPatternMatch _ expr pats) -> - (expr $> (), fmap (bimap void void) pats ) - Right other -> error $ "expected pattern match, got " <> show other - Left e -> error (show e) +import Test.Hspec spec :: Spec spec = do fdescribe "PatternsSpec" $ do it "Wildcard is exhaustive" $ do - let (_expr,pats) = unsafeParsePattern "case a of _ -> 2" - generateMissing (fst <$> pats) `shouldBe` [] + let pats = NE.fromList [PWildcard tyInt] + generateMissing @() pats `shouldBe` [] it "True needs false" $ do - let (_expr,pats) = unsafeParsePattern "case a of True -> 2" - generateMissing (fst <$> pats) `shouldBe` [patBool False ] + let pats = NE.fromList [PLiteral tyBool (PBool True)] + generateMissing @() pats `shouldBe` [patBool False] it "False needs True" $ do - let (_expr,pats) = unsafeParsePattern "case a of False -> 2" - generateMissing (fst <$> pats) `shouldBe` [patBool True] + let pats = NE.fromList [PLiteral tyBool (PBool False)] + generateMissing @() pats `shouldBe` [patBool True] it "False and True needs nothing" $ do - let (_expr,pats) = unsafeParsePattern "case a of False -> 2 | True -> 1" - generateMissing (fst <$> pats) `shouldBe` [] + let pats = NE.fromList [PLiteral tyBool (PBool False), PLiteral tyBool (PBool True)] + generateMissing @() pats `shouldBe` [] it "Tuple of two wildcards needs nothing " $ do - let (_expr,pats) = unsafeParsePattern "case a of (_,_) -> True" - generateMissing (fst <$> pats) `shouldBe` [] + let pats = + NE.fromList + [ PTuple + (tyTuple [tyBool, tyBool]) + (PWildcard tyBool) + (NE.fromList [PWildcard tyBool]) + ] + generateMissing @() pats `shouldBe` [] + it "Tuple of one wildcard and one true needs a false" $ do - let (_expr,pats) = unsafeParsePattern "case a of (_,True) -> True" - generateMissing (fst <$> pats) `shouldBe` [patTuple [PWildcard (), patBool False]] + let pats = + NE.fromList + [ PTuple (tyTuple [tyBool, tyBool]) (PWildcard tyBool) (NE.fromList [PLiteral tyBool (PBool True)]) + ] + generateMissing @() pats + `shouldBe` [ patTuple [patBool True, patBool False], + patTuple [patBool False, patBool False] + ] + it "Tuple of one true and one false needs a bunch" $ do - let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True" - generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], - patTuple [patBool True, patBool True], - patTuple [patBool True, patBool False]] - it "Tuple of booleans with some things supplied" $ do - let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True | (True, False) -> False" - generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], - patTuple [patBool True, patBool True]] + let pats = + NE.fromList + [ PTuple (tyTuple [tyBool, tyBool]) (PLiteral tyBool (PBool False)) (NE.fromList [PLiteral tyBool (PBool True)]) + ] + generateMissing @() pats + `shouldBe` [ patTuple [patBool False, patBool False], + patTuple [patBool True, patBool True], + patTuple [patBool True, patBool False] + ] it "Tuple of booleans with some things supplied" $ do - let (_expr,pats) = unsafeParsePattern "case a of (False,True) -> True | (True, False) -> False" - generateMissing (fst <$> pats) `shouldBe` [patTuple [patBool False, patBool False], - patTuple [patBool True, patBool True]] + let pats = + NE.fromList + [ PTuple (tyTuple [tyBool, tyBool]) (PLiteral tyBool (PBool False)) (NE.fromList [PLiteral tyBool (PBool True)]), + PTuple (tyTuple [tyBool, tyBool]) (PLiteral tyBool (PBool True)) (NE.fromList [PLiteral tyBool (PBool False)]) + ] + generateMissing @() pats + `shouldBe` [ patTuple [patBool False, patBool False], + patTuple [patBool True, patBool True] + ] it "Tuple of wildcard and boolean" $ do - let (_expr,pats) = unsafeParsePattern "case a of (False,_) -> True | (True, False) -> False" - generateMissing (fst <$> pats) `shouldBe` [ - patTuple [patBool True, patBool True]] - - - - - - - + let pats = + NE.fromList + [ PTuple (tyTuple [tyBool, tyBool]) (PLiteral tyBool (PBool False)) (NE.fromList [PWildcard tyBool]), + PTuple (tyTuple [tyBool, tyBool]) (PLiteral tyBool (PBool True)) (NE.fromList [PLiteral tyBool (PBool False)]) + ] + generateMissing @() pats + `shouldBe` [ patTuple [patBool True, patBool True] + ] From abed1c1ec872e0362d69bfa763f30d85f2c64f72 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Fri, 26 May 2023 15:34:21 +0100 Subject: [PATCH 12/13] Do completeness checking --- llvm-calc4/a.out | Bin 51200 -> 0 bytes llvm-calc4/llvm-calc4.cabal | 7 ++++- llvm-calc4/src/Calc/Patterns/Flatten.hs | 6 ++--- llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 6 ++++- llvm-calc4/src/Calc/Typecheck/Error.hs | 25 ++++++++++++++++++ llvm-calc4/src/Calc/Types/Annotation.hs | 5 ++-- llvm-calc4/test/Test/Patterns/PatternsSpec.hs | 2 +- 7 files changed, 43 insertions(+), 8 deletions(-) delete mode 100755 llvm-calc4/a.out diff --git a/llvm-calc4/a.out b/llvm-calc4/a.out deleted file mode 100755 index b31b2ecaeaba0cc7253cc336d693c087080102ce..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 51200 zcmeI*ZERCj7zgmvc5j3+H+*3#x(ZCd&~+8IK|maAgMqlgSj>D$&f4{E?daS3vaO7b zCJ4c3;xGn?35c2?izKpgo@e3rz|C5~F z^W1apbI<+tu6#Or?&cr2a)mGn3Ny9o)Dl@jY!_DQ3Gq0!T53|RU$MI8wVH-{KARfq z#i^(6$hgitx`C99HBF7F!|L^wspo{AXVPhIv?Mj8j4ItxJ)uK-y}{+$!x&RI=jWP; z43U`VQ%A~3H0Fv-oJy}(`l42kjMU&fxn7Q5&luP0Nsrh%`u7jbS$E|Ave_hc_MDXUjBu8>vrbsM#}rBZ)M>8M89k!y0} zcPXEmlx;!d9iVj5dO1AEQJK0fe&3ScYhKrH)Ug1jl>Fdi$XzMfTaegoVS}G;wx;1q#)l{$1 z^o1Bsdrr#L&53bCLZsnoab9dglMMY zRBG#~aTxb7eh-ozv>CTQ^XdA#JtzO2SgG;ft34J8+kIY_-5vM2Dd%{enw1*gOS}72 zwXfB3nxCVy{C*iRkDim)eVTq}l7s6nrp9BuPSdhnTDIx-J?Xw1CTfVT| zW20i;mSCXMR#vJ$o{=ZtXC|$?TfbjY?YOVIrOeq;R^Hy>b}Z@iyE@uJZLWBGpxxiK zq$BROIYS|zVpIR%yMp0OwklfDs$eWYZ=p9BusIwhr6r~IbzWE48IIdm@ec#foA`l! zmhQ=ZJBPLJdA^?W_s;PYH6BmU_s9PZAs_$&2tWV=5P$##AOHafKmY;|fB*y_009U< z00Izz00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00I#Bj|pUm zf?Y)KX zewZu`&fA&>E&DFYH(kT>b(5@ zrhWgkeb$(XBLpA-0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Izz z00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Iv{pi%$-vp;}a z;T8~p00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Izz00bZa z0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009V0B2X|}Y~Z1SN_X7nmXT=8 zC0*WtTM3H|VbL!lG$KP`Zy?%A8HGuTOpzh|9;tLIEwQK~J&Kd3+q{7m;qW`X0a31A zAeNeGzC{f7(#PpuAU~f+0JUZ4>%bOxZPJ(VleJ=5k zDW5*I{4q7B`OCHZNj*QU=g;VQhW_6*SfJ+@_3?7bsV?WVEOIIT7g1@^O{*B}07&&8PQ=Z$Y{?SJp&wRg`o_s%^( zRClKM)$)VK@Av&@&%SWNGkPU7&0FkPx_jo;k7TE1Wo-BH9jBBNXXbwYF}*$fwyST) U%}Z0cpEKEaiSrPE!2e0$Z;T;QWB>pF diff --git a/llvm-calc4/llvm-calc4.cabal b/llvm-calc4/llvm-calc4.cabal index ea859449..6bb3cf42 100644 --- a/llvm-calc4/llvm-calc4.cabal +++ b/llvm-calc4/llvm-calc4.cabal @@ -61,10 +61,13 @@ common shared Calc.Parser.Function Calc.Parser.Identifier Calc.Parser.Module + Calc.Parser.Pattern Calc.Parser.Primitives Calc.Parser.Shared Calc.Parser.Type Calc.Parser.Types + Calc.Patterns.Flatten + Calc.PatternUtils Calc.Repl Calc.SourceSpan Calc.Typecheck.Elaborate @@ -77,9 +80,11 @@ common shared Calc.Types.FunctionName Calc.Types.Identifier Calc.Types.Module + Calc.Types.Pattern Calc.Types.Prim Calc.Types.Type Calc.TypeUtils + Calc.Utils library import: shared @@ -110,8 +115,8 @@ test-suite llvm-calc4-tests Test.Typecheck.TypecheckSpec executable llvm-calc4 - main-is: Main.hs import: shared + main-is: Main.hs hs-source-dirs: app hs-source-dirs: src ghc-options: -threaded -rtsopts -with-rtsopts=-N diff --git a/llvm-calc4/src/Calc/Patterns/Flatten.hs b/llvm-calc4/src/Calc/Patterns/Flatten.hs index 715c0e5e..80fa90a1 100644 --- a/llvm-calc4/src/Calc/Patterns/Flatten.hs +++ b/llvm-calc4/src/Calc/Patterns/Flatten.hs @@ -10,13 +10,13 @@ import qualified Data.List import qualified Data.List.NonEmpty as NE -- | given our patterns, generate everything we need minus the ones we have -generateMissing :: (Show ann) => NE.NonEmpty (Pattern (Type ann)) -> [Pattern ()] +generateMissing :: NE.NonEmpty (Pattern (Type ann)) -> [Pattern ()] generateMissing nePats = let pats = NE.toList nePats in filterMissing (removeAnn <$> pats) (generatePatterns pats) -- | given our patterns, generate any others we might need -generatePatterns :: (Show ann) => [Pattern (Type ann)] -> [Pattern ()] +generatePatterns :: [Pattern (Type ann)] -> [Pattern ()] generatePatterns = concatMap generatePattern generateForType :: Type ann -> [Pattern ()] @@ -30,7 +30,7 @@ typeIsTotal (TTuple {}) = True typeIsTotal (TFunction {}) = True -- | given a pattern, generate all other patterns we'll need -generatePattern :: forall ann. (Show ann) => Pattern (Type ann) -> [Pattern ()] +generatePattern :: forall ann. Pattern (Type ann) -> [Pattern ()] generatePattern (PWildcard _) = mempty generatePattern (PLiteral _ (PBool True)) = [PLiteral () (PBool False)] generatePattern (PLiteral _ (PBool False)) = [PLiteral () (PBool True)] diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index c34be1ae..a095c2eb 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -4,6 +4,7 @@ module Calc.Typecheck.Elaborate (elaborate, elaborateFunction, elaborateModule) where +import Calc.Patterns.Flatten import Calc.ExprUtils import Calc.PatternUtils import Calc.TypeUtils @@ -181,7 +182,7 @@ infer (ETuple ann fstExpr restExpr) = do (getOuterAnnotation typedFst) (getOuterAnnotation <$> typedRest) pure $ ETuple typ typedFst typedRest -infer (EPatternMatch _ann matchExpr pats) = do +infer (EPatternMatch ann matchExpr pats) = do elabExpr <- infer matchExpr let withPair (pat, patExpr) = do (elabPat, newVars) <- checkPattern (getOuterAnnotation elabExpr) pat @@ -190,6 +191,9 @@ infer (EPatternMatch _ann matchExpr pats) = do elabPats <- traverse withPair pats let allTypes = getOuterAnnotation . snd <$> elabPats typ <- checkTypesAreEqual allTypes + case generateMissing (fst <$> elabPats) of + [] -> pure () + missingPatterns -> throwError (IncompletePatterns ann missingPatterns) pure (EPatternMatch typ elabExpr elabPats) infer (EApply ann fnName args) = do fn <- lookupFunction ann fnName diff --git a/llvm-calc4/src/Calc/Typecheck/Error.hs b/llvm-calc4/src/Calc/Typecheck/Error.hs index 5c5887b1..559293a5 100644 --- a/llvm-calc4/src/Calc/Typecheck/Error.hs +++ b/llvm-calc4/src/Calc/Typecheck/Error.hs @@ -31,6 +31,7 @@ data TypeError ann | PatternMismatch (Pattern ann) (Type ann) | FunctionArgumentLengthMismatch ann Int Int -- expected, actual | NonFunctionTypeFound ann (Type ann) + | IncompletePatterns ann [Pattern ()] deriving stock (Eq, Ord, Show) positionFromAnnotation :: @@ -210,6 +211,24 @@ typeErrorDiagnostic input e = ] ) [Diag.Note $ "Available in scope: " <> prettyPrint (prettyHashset existing)] + (IncompletePatterns ann missingPatterns) -> + Diag.Err + Nothing + "Pattern match is incomplete!" + ( catMaybes + [ (,) + <$> positionFromAnnotation + filename + input + ann + <*> pure + ( Diag.This $ + prettyPrint $ "Missing patterns: " <> PP.line <> prettyListToLines missingPatterns + ) + ] + ) + [] + in Diag.addReport diag report -- | becomes "a, b, c, d" @@ -219,6 +238,12 @@ prettyHashset hs = (PP.surround PP.comma) (PP.pretty <$> HS.toList hs) +prettyListToLines :: (PP.Pretty a) => [a] -> PP.Doc ann +prettyListToLines as + = PP.concatWith + (PP.surround PP.line) + (PP.pretty <$> as) + renderWithWidth :: Int -> PP.Doc ann -> Text renderWithWidth w doc = PP.renderStrict (PP.layoutPretty layoutOptions (PP.unAnnotate doc)) where diff --git a/llvm-calc4/src/Calc/Types/Annotation.hs b/llvm-calc4/src/Calc/Types/Annotation.hs index 947b9d84..fd54572e 100644 --- a/llvm-calc4/src/Calc/Types/Annotation.hs +++ b/llvm-calc4/src/Calc/Types/Annotation.hs @@ -11,9 +11,10 @@ where data Annotation = Location Int Int deriving stock (Eq, Ord, Show) --- | when combining two `Annotation`, take the first one +-- | when combining two `Annotation`, combine to make one big annotation instance Semigroup Annotation where - a <> _ = a + (Location start end) <> (Location start' end') = + Location (min start start') (max end end') -- | Default to an empty `Annotation` instance Monoid Annotation where diff --git a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs index b84dea73..85ea3324 100644 --- a/llvm-calc4/test/Test/Patterns/PatternsSpec.hs +++ b/llvm-calc4/test/Test/Patterns/PatternsSpec.hs @@ -11,7 +11,7 @@ import Test.Hspec spec :: Spec spec = do - fdescribe "PatternsSpec" $ do + describe "PatternsSpec" $ do it "Wildcard is exhaustive" $ do let pats = NE.fromList [PWildcard tyInt] generateMissing @() pats `shouldBe` [] From 509b888bb6e1c18bd96a941f4ddc04bcb04040ae Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Fri, 2 Jun 2023 09:29:15 +0100 Subject: [PATCH 13/13] Broken test --- llvm-calc4/src/Calc/Typecheck/Elaborate.hs | 2 +- llvm-calc4/src/Calc/Typecheck/Error.hs | 14 +++++++------- llvm-calc4/src/Calc/Types/Annotation.hs | 2 +- llvm-calc4/test/Test/LLVM/LLVMSpec.hs | 20 ++++++++++++++++++-- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs index a095c2eb..ba3fa212 100644 --- a/llvm-calc4/src/Calc/Typecheck/Elaborate.hs +++ b/llvm-calc4/src/Calc/Typecheck/Elaborate.hs @@ -4,9 +4,9 @@ module Calc.Typecheck.Elaborate (elaborate, elaborateFunction, elaborateModule) where -import Calc.Patterns.Flatten import Calc.ExprUtils import Calc.PatternUtils +import Calc.Patterns.Flatten import Calc.TypeUtils import Calc.Typecheck.Error import Calc.Typecheck.Types diff --git a/llvm-calc4/src/Calc/Typecheck/Error.hs b/llvm-calc4/src/Calc/Typecheck/Error.hs index 559293a5..ae4b879d 100644 --- a/llvm-calc4/src/Calc/Typecheck/Error.hs +++ b/llvm-calc4/src/Calc/Typecheck/Error.hs @@ -222,13 +222,13 @@ typeErrorDiagnostic input e = input ann <*> pure - ( Diag.This $ - prettyPrint $ "Missing patterns: " <> PP.line <> prettyListToLines missingPatterns + ( Diag.This $ + prettyPrint $ + "Missing patterns: " <> PP.line <> prettyListToLines missingPatterns ) ] ) [] - in Diag.addReport diag report -- | becomes "a, b, c, d" @@ -239,10 +239,10 @@ prettyHashset hs = (PP.pretty <$> HS.toList hs) prettyListToLines :: (PP.Pretty a) => [a] -> PP.Doc ann -prettyListToLines as - = PP.concatWith - (PP.surround PP.line) - (PP.pretty <$> as) +prettyListToLines as = + PP.concatWith + (PP.surround PP.line) + (PP.pretty <$> as) renderWithWidth :: Int -> PP.Doc ann -> Text renderWithWidth w doc = PP.renderStrict (PP.layoutPretty layoutOptions (PP.unAnnotate doc)) diff --git a/llvm-calc4/src/Calc/Types/Annotation.hs b/llvm-calc4/src/Calc/Types/Annotation.hs index fd54572e..28992bdc 100644 --- a/llvm-calc4/src/Calc/Types/Annotation.hs +++ b/llvm-calc4/src/Calc/Types/Annotation.hs @@ -14,7 +14,7 @@ data Annotation = Location Int Int -- | when combining two `Annotation`, combine to make one big annotation instance Semigroup Annotation where (Location start end) <> (Location start' end') = - Location (min start start') (max end end') + Location (min start start') (max end end') -- | Default to an empty `Annotation` instance Monoid Annotation where diff --git a/llvm-calc4/test/Test/LLVM/LLVMSpec.hs b/llvm-calc4/test/Test/LLVM/LLVMSpec.hs index c8666127..55e203c8 100644 --- a/llvm-calc4/test/Test/LLVM/LLVMSpec.hs +++ b/llvm-calc4/test/Test/LLVM/LLVMSpec.hs @@ -28,6 +28,9 @@ testCompileIR (input, result) = it (show input) $ do resp <- run (moduleToLLVM typedExpr) resp `shouldBe` result +joinLines :: [Text] -> Text +joinLines = foldr (\a b -> a <> " " <> b) "" + spec :: Spec spec = do describe "LLVMSpec" $ do @@ -40,10 +43,23 @@ spec = do ("if False then 1 else 2", "2"), ("if 1 == 1 then 7 else 10", "7"), ("if 2 == 1 then True else False", "False"), - ("function one() { 1 } function two() { 2 } one() + two()", "3"), + ( joinLines + [ "function one() { 1 }", + "function two() { 2 }", + "one() + two()" + ], + "3" + ), ("function increment(a: Integer) { a + 1 } increment(41)", "42"), ("function sum(a: Integer, b: Integer) { a + b } sum(20,22)", "42"), - ("function inc(a: Integer) { a + 1 } inc(inc(inc(inc(0))))", "4") + ("function inc(a: Integer) { a + 1 } inc(inc(inc(inc(0))))", "4"), + ( joinLines + [ "function swapIntAndBool(pair: (Integer, Boolean)) { case pair of (a, b) -> (b, a) }", + "function fst(pair: (Boolean, Integer)) { case pair of (a,_) -> a }", + "fst(swapIntAndBool((1,True)))" + ], + "True" -- note we cannot make polymorphic versions of these functions yet, although we will + ) ] describe "From modules" $ do