Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/Data/V/Linear.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
-- isTrue :: Bool
-- isTrue = V.elim doSomething (build 4 9)
-- where
-- -- GHC can't figure out this type equality, so this is needed.
-- build :: Int %1-> Int %1-> V.V 2 Int
-- build = V.make @2 @Int
-- build = V.make
-- :}
--
-- A much more expensive library of vectors of known size (including matrices
Expand All @@ -48,8 +47,8 @@ module Data.V.Linear
fromReplicator,
dupV,
theLength,
Make,
make,
FunN,
)
where

Expand Down
123 changes: 54 additions & 69 deletions src/Data/V/Linear/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,22 @@ module Data.V.Linear.Internal
fromReplicator,
dupV,
theLength,
Make,
make,
FunN,
)
where

import Data.Arity.Linear.Internal.Arity
import Data.Kind
import Data.Proxy
import Data.Replicator.Linear.Internal (Replicator)
import qualified Data.Replicator.Linear.Internal as Replicator
import Data.Type.Equality
import Data.Unrestricted.Internal.Dupable (Dupable (dupR))
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import GHC.Exts (proxy#)
import GHC.TypeLits
import Prelude.Linear.Internal
import qualified Unsafe.Linear as Unsafe
import Prelude (Bool (..), Either (..), Maybe (..), error, (-))
import qualified Prelude

-- | @'V' n a@ represents an immutable sequence of @n@ elements of type @a@
Expand All @@ -72,9 +69,6 @@ consume = Unsafe.toLinear (\_ -> ())
map :: (a %1 -> b) -> V n a %1 -> V n b
map f (V xs) = V $ Unsafe.toLinear (Vector.map (\x -> f x)) xs

pure :: forall n a. KnownNat n => a -> V n a
pure a = V $ Vector.replicate (theLength @n) a

(<*>) :: V n (a %1 -> b) %1 -> V n a %1 -> V n b
(V fs) <*> (V xs) =
V $
Expand All @@ -96,6 +90,8 @@ uncons = Unsafe.toLinear uncons'

-- | @'Elim' n a b f@ asserts that @f@ is a function taking @n@ linear arguments
-- of type @a@ and then returning a value of type @b@.
--
-- It is solely used to define the type of the 'elim' function.
type Elim :: Nat -> Type -> Type -> Type -> Constraint
class (n ~ Arity b f) => Elim n a b f | n a b -> f, f b -> n where
-- | Takes a function of type @a %1 -> a %1 -> ... %1 -> a %1 -> b@, and
Expand Down Expand Up @@ -125,13 +121,47 @@ instance (1 <= n, n ~ Arity b (a %1 -> f), Elim (n - 1) a b f) => Elim n a b (a
cons :: forall n a. a %1 -> V (n - 1) a %1 -> V n a
cons = Unsafe.toLinear2 $ \x (V v) -> V (Vector.cons x v)

-- | Creates a 'V' of the specified size by consuming a 'Replicator'.
fromReplicator :: forall n a. (KnownNat n, Replicator.Elim n a (V n a) (FunN n a (V n a))) => Replicator a %1 -> V n a
fromReplicator = Replicator.elim @n @a @(V n a) @(FunN n a (V n a)) (make @n @a)

-- | Produces a @'V' n a@ from a 'Dupable' value @a@.
dupV :: forall n a. (KnownNat n, Dupable a, Replicator.Elim n a (V n a) (FunN n a (V n a))) => a %1 -> V n a
dupV = fromReplicator . dupR
-- | Builds a n-ary constructor for @'V' n a@ (i.e. a function taking @n@ linear
-- arguments of type @a@ and returning a @'V' n a@).
--
-- > myV :: V 3 Int
-- > myV = make 1 2 3
make :: forall n a f. Make n n a f => f
make = make' @n @n id
{-# INLINE make #-}

-- | @'Make' m n a f@ asserts that @f@ is a function taking @m@ linear arguments
-- of type @a@ and then returning a value of type @'V' n a@.
--
-- It is solely used to define the type of the 'make' function.
type Make :: Nat -> Nat -> Type -> Type -> Constraint
class (m ~ Arity (V n a) f) => Make m n a f | f -> m n a where
-- The idea behind Make / make' / make is the following:
--
-- f takes m values of type a, but returns a 'V n a' (with n ≥ m),
-- so the n - m missing values must be supplied via the accumulator.
--
-- @make' is initially called with m = n via make, and as m decreases,
-- the number of lambda on the left increases and the captured values are put
-- in the accumulator
-- ('V[ ... ] <> ' represents the "extend" operation for 'V'):
--
-- > make @n
-- > --> make' @n @n (V[] <>)
-- > --> λx. make' @(n - 1) @n (V[x] <>)
-- > --> λx. λy. make' @(n - 2) @n (V[x, y] <>)
-- > --> ...
-- > --> λx. λy. ... λz. make' @0 @n (V[x, y, ... z] <>) -- make' @0 @n f = f V[]
-- > --> λx. λy. ... λz. V[x, y, ... z]
make' :: (V m a %1 -> V n a) %1 -> f

instance Make 0 n a (V n a) where
make' produceFrom = produceFrom (V Vector.empty)
{-# INLINE make' #-}

instance (m ~ Arity (V n a) (a %1 -> f), Make (m - 1) n a f) => Make m n a (a %1 -> f) where
make' produceFrom = \x -> make' @(m - 1) @n @a (\s -> produceFrom $ cons x s)
{-# INLINE make' #-}

-------------------------------------------------------------------------------
-- Functions below use AllowAmbiguousTypes
Expand All @@ -141,58 +171,13 @@ dupV = fromReplicator . dupR
theLength :: forall n. KnownNat n => Prelude.Int
theLength = Prelude.fromIntegral (natVal' @n (proxy# @_))

-- Make implementation, which needs to be improved

-- | Builds a n-ary constructor for @'V' n a@ (i.e. a function taking @n@
-- elements of type @a@ and returning a @'V' n a@).
make :: forall n a. KnownNat n => FunN n a (V n a)
make = case caseNat @n of
Left Refl -> V Vector.empty
Right Refl -> contractFunN @n @a @(V n a) prepend
where
prepend :: a %1 -> FunN (n - 1) a (V n a)
prepend t = case predNat @n of
Dict -> continue @(n - 1) @a @(V (n - 1) a) (cons t) (make @(n - 1) @a)

-- Helper functions/types for 'make' to typecheck

data Dict (c :: Constraint) where
Dict :: c => Dict c

type family FunN (n :: Nat) (a :: Type) (b :: Type) :: Type where
FunN 0 a b = b
FunN n a b = a %1 -> FunN (n - 1) a b

predNat :: forall n. (1 <= n, KnownNat n) => Dict (KnownNat (n - 1))
predNat = case someNatVal (natVal' @n (proxy# @_) - 1) of
Just (SomeNat (_ :: Proxy p)) -> Unsafe.coerce (Dict @(KnownNat p))
Nothing -> error "Vector.pred: n-1 is necessarily a Nat, if 1<=n"

caseNat :: forall n. KnownNat n => Either (n :~: 0) ((1 <=? n) :~: 'True)
caseNat =
case theLength @n of
0 -> Left $ unsafeZero @n
_ -> Right $ unsafeNonZero @n
{-# INLINE caseNat #-}

-- By definition.
expandFunN :: forall n a b. (1 <= n) => FunN n a b %1 -> a %1 -> FunN (n - 1) a b
expandFunN k = Unsafe.coerce k

-- By definition.
contractFunN :: (1 <= n) => (a %1 -> FunN (n - 1) a b) %1 -> FunN n a b
contractFunN k = Unsafe.coerce k

continue :: forall n a b c. KnownNat n => (b %1 -> c) %1 -> FunN n a b %1 -> FunN n a c
continue = case caseNat @n of
Left Refl -> id
Right Refl -> \f t -> contractFunN @n @a @c (continueS f (expandFunN @n @a @b t))
where
continueS :: (KnownNat n, 1 <= n) => (b %1 -> c) %1 -> (a %1 -> FunN (n - 1) a b) %1 -> (a %1 -> FunN (n - 1) a c)
continueS f' x a = case predNat @n of Dict -> continue @(n - 1) @a @b f' (x a)

unsafeZero :: n :~: 0
unsafeZero = Unsafe.coerce Refl

unsafeNonZero :: (1 <=? n) :~: 'True
unsafeNonZero = Unsafe.coerce Refl
pure :: forall n a. KnownNat n => a -> V n a
pure a = V $ Vector.replicate (theLength @n) a

-- | Creates a 'V' of the specified size by consuming a 'Replicator'.
fromReplicator :: forall n a. KnownNat n => Replicator a %1 -> V n a
fromReplicator = let n' = theLength @n in V . Unsafe.toLinear Vector.fromList . Replicator.take n'

-- | Produces a @'V' n a@ from a 'Dupable' value @a@.
dupV :: forall n a. (KnownNat n, Dupable a) => a %1 -> V n a
dupV = fromReplicator . dupR