module Data.Type.Vector where
import Data.Type.Combinator
import Data.Type.Fin
import Data.Type.Length
import Data.Type.Nat
import Data.Type.Product (Prod(..),curry',pattern (:>))
import Type.Class.Known
import Type.Class.Witness
import Type.Family.Constraint
import Type.Family.List
import Type.Family.Nat
import qualified Data.List as L
import Data.Monoid
-- | A length-indexed vector whose elements all have type @f a@.
--
-- @n@ is the type-level length; the element functor @f@ and the index @a@
-- are shared by every entry.
data VT (n :: N) (f :: k -> *) :: k -> * where
  -- | The empty vector, of length 'Z'.
  ØV :: VT Z f a
  -- | Prepend one element to a vector of length @n@, giving length @S n@.
  -- Both fields are strict.
  (:*) :: !(f a) -> !(VT n f a) -> VT (S n) f a
infixr 4 :*
-- | Eliminate a vector into a result indexed by its length.
--
-- The first argument is the result for the empty vector. The step function
-- combines one element with the result already computed for the tail.
elimVT :: p Z
       -> (forall x. f a -> p x -> p (S x))
       -> VT n f a
       -> p n
elimVT nil _ ØV = nil
elimVT nil step (x :* xs) = step x (elimVT nil step xs)
-- | 'elimVT' for plain vectors ('V').
--
-- The step function receives each element already unwrapped from 'I'.
elimV :: p Z
      -> (forall x. a -> p x -> p (S x))
      -> V n a
      -> p n
elimV nil step = elimVT nil (\boxed acc -> step (getI boxed) acc)
-- | A vector of plain values: every element is wrapped in the identity
-- functor 'I'.
type V n = VT n I
-- | Cons for plain vectors. It matches or builds an 'I'-wrapped head
-- followed by a tail, via ':*'.
pattern (:+) :: a -> V n a -> V (S n) a
pattern a :+ as = I a :* as
infixr 4 :+
-- Standalone-derived equality, ordering and 'Show' for vectors. Each is
-- available whenever the element type @f a@ has the corresponding instance.
deriving instance Eq (f a) => Eq (VT n f a)
deriving instance Ord (f a) => Ord (VT n f a)
deriving instance Show (f a) => Show (VT n f a)
-- | Append two vectors. The result length is the type-level sum of the
-- input lengths.
(.++) :: VT x f a -> VT y f a -> VT (x + y) f a
ØV .++ ys = ys
(x :* xs) .++ ys = x :* (xs .++ ys)
infixr 5 .++
-- | Build a vector of length @n@, where @n@ comes from the type.
-- Every slot holds the same given element.
vrep :: forall n f a. Known Nat n => f a -> VT n f a
vrep a = fill (known :: Nat n)
  where
    -- Walk the length witness, emitting one copy of @a@ per successor.
    fill :: Nat m -> VT m f a
    fill Z_ = ØV
    fill (S_ k) = a :* fill k
-- | The first element of a non-empty vector. Totality is guaranteed by the
-- @S n@ index.
head' :: VT (S n) f a -> f a
head' = \case
  x :* _ -> x
-- | Everything after the first element of a non-empty vector.
tail' :: VT (S n) f a -> VT n f a
tail' = \case
  _ :* rest -> rest
-- | Apply a transformation to the tail of a non-empty vector, keeping the
-- head in place.
onTail :: (VT m f a -> VT n f a) -> VT (S m) f a -> VT (S n) f a
onTail g = \case
  x :* rest -> x :* g rest
-- | Remove the element at the given index. The result's length is
-- @Pred n@.
vDel :: Fin n -> VT n f a -> VT (Pred n) f a
vDel = \case
  -- Deleting position zero just drops the head.
  FZ -> tail'
  -- Otherwise keep the head and delete one position further into the tail.
  -- NOTE(review): the trailing @\\ x@ brings a constraint into scope from
  -- the 'Witness' instance of the inner index (defined outside this file).
  -- Presumably it proves that index type is a successor, so the
  -- @Pred@-indexed result type lines up. Confirm against Data.Type.Fin.
  FS x -> onTail (vDel x) \\ x
-- | Map over a vector, also passing each element's position.
imap :: (Fin n -> f a -> g b) -> VT n f a -> VT n g b
imap _ ØV = ØV
imap g (x :* xs) = g FZ x :* imap (\i -> g (FS i)) xs
-- | Monoidal fold over a vector, also passing each element's position.
ifoldMap :: Monoid m => (Fin n -> f a -> m) -> VT n f a -> m
ifoldMap _ ØV = mempty
ifoldMap g (x :* xs) = g FZ x <> ifoldMap (\i -> g (FS i)) xs
-- | Effectful traversal of a vector, also passing each element's position.
-- Effects run from the head towards the end.
itraverse :: Applicative h => (Fin n -> f a -> h (g b)) -> VT n f a -> h (VT n g b)
itraverse _ ØV = pure ØV
itraverse g (x :* xs) = fmap (:*) (g FZ x) <*> itraverse (\i -> g (FS i)) xs
-- | Look up the element at a position. The 'Fin' index is always in
-- bounds for a vector of length @n@.
index :: Fin n -> VT n f a -> f a
index FZ v = head' v
index (FS i) v = index i (tail' v)
-- | Apply a function to every element, preserving the length.
vmap :: (f a -> g b) -> VT n f a -> VT n g b
vmap _ ØV = ØV
vmap g (x :* xs) = g x :* vmap g xs
-- | Combine two vectors of the same length element-wise.
--
-- Both vectors share the index @n@, so once the first argument is known to
-- be non-empty the second must be too. The two equations below are
-- therefore exhaustive, and no runtime "impossible" branch is needed.
-- The second vector is not inspected when the first is empty.
vap :: (f a -> g b -> h c) -> VT n f a -> VT n g b -> VT n h c
vap _ ØV _ = ØV
vap f (a :* as) (b :* bs) = f a b :* vap f as bs
-- | Right-associative fold over a vector.
vfoldr :: (f a -> b -> b) -> b -> VT n f a -> b
vfoldr _ acc ØV = acc
vfoldr step acc (x :* xs) = step x (vfoldr step acc xs)
-- | Map each element with @f@ and join the results right-associatively
-- with @j@.
--
-- An empty vector yields @z@. A non-empty vector bottoms out at its last
-- element, so @z@ is never joined with a real element.
vfoldMap' :: (b -> b -> b) -> b -> (f a -> b) -> VT n f a -> b
vfoldMap' _ z _ ØV = z
vfoldMap' _ _ f (x :* ØV) = f x
vfoldMap' j z f (x :* xs) = j (f x) (vfoldMap' j z f xs)
-- | Map every element into a monoid and combine the results with '<>'.
-- The empty vector gives 'mempty'.
vfoldMap :: Monoid m => (f a -> m) -> VT n f a -> m
vfoldMap f = vfoldr (\x acc -> f x <> acc) mempty
-- | Convert a list to a vector of statically unknown length.
-- The continuation must work for every possible length.
withVT :: [f a] -> (forall n. VT n f a -> r) -> r
withVT [] k = k ØV
withVT (x : xs) k = withVT xs (\v -> k (x :* v))
-- | 'withVT' for plain lists: each element is wrapped in 'I' first.
withV :: [a] -> (forall n. V n a -> r) -> r
withV xs k = withVT (map I xs) k
-- | Position of the first element equal to the given value in a plain
-- vector, if any.
findV :: Eq a => a -> V n a -> Maybe (Fin n)
findV x = findVT (I x)
-- | Position of the first element equal to the given one, scanning from
-- the head. Returns 'Nothing' when there is no match.
findVT :: Eq (f a) => f a -> VT n f a -> Maybe (Fin n)
findVT _ ØV = Nothing
findVT needle (x :* rest)
  | needle == x = Just FZ
  | otherwise = fmap FS (findVT needle rest)
-- | Maps the innermost values, reaching through both the vector and @f@.
instance Functor f => Functor (VT n f) where
  fmap g = vmap (fmap g)
-- | Zip-style applicative. 'pure' replicates to the known length, and
-- '<*>' applies position by position using @f@'s own '<*>'.
instance (Applicative f, Known Nat n) => Applicative (VT n f) where
  pure x = vrep (pure x)
  fs <*> xs = vap (<*>) fs xs
-- | Diagonal bind. The element at position @i@ is bound in @f@ against the
-- function's result, and position @i@ is then taken from the vector that
-- function returns.
instance (Monad f, Known Nat n) => Monad (VT n f) where
  v >>= k = imap (\i fx -> fx >>= (index i . k)) v
-- | Folds over every value inside every element, head first.
instance Foldable f => Foldable (VT n f) where
  foldMap _ ØV = mempty
  foldMap g (x :* xs) = foldMap g x <> foldMap g xs
-- | Traverses every value inside every element, head first.
instance Traversable f => Traversable (VT n f) where
  traverse _ ØV = pure ØV
  traverse g (x :* xs) = fmap (:*) (traverse g x) <*> traverse g xs
-- | A vector value witnesses @Known Nat n@ for its own length.
-- The proof comes from walking the vector's spine down to 'ØV'.
instance Witness ØC (Known Nat n) (VT n f a) where
  (\\) r = \case
    -- Empty vector: the length is 'Z'.
    ØV -> r
    -- Recurse on the tail. NOTE(review): this presumably relies on a
    -- @Known Nat m => Known Nat (S m)@ instance from Type.Class.Known to
    -- lift the tail's constraint to the whole vector. Confirm.
    _ :* as -> r \\ as
-- | Element-wise arithmetic on vectors of equal length.
--
-- Binary operators combine position by position via 'vap'. Unary ones map
-- over every element via 'vmap'. 'fromInteger' replicates the literal into
-- every slot, which is why the length must be 'Known'.
instance (Num (f a), Known Nat n) => Num (VT n f a) where
  (*) = vap (*)
  (+) = vap (+)
  (-) = vap (-)
  negate = vmap negate
  abs = vmap abs
  signum = vmap signum
  fromInteger = vrep . fromInteger
-- | A matrix with one dimension per entry of the type-level list @ns@.
-- It wraps the nested-vector representation given by the 'Matrix' family.
newtype M ns a = M { getMatrix :: Matrix ns a }
-- Instances are inherited from the underlying 'Matrix' representation.
deriving instance Eq (Matrix ns a) => Eq (M ns a)
deriving instance Ord (Matrix ns a) => Ord (M ns a)
deriving instance Show (Matrix ns a) => Show (M ns a)
-- | Arithmetic on matrices, delegated entirely to the underlying
-- 'Matrix' representation.
--
-- For non-empty @ns@ that representation is nested vectors, so the
-- operations act element-wise.
instance Num (Matrix ns a) => Num (M ns a) where
  fromInteger = M . fromInteger
  M a * M b = M $ a * b
  M a + M b = M $ a + b
  M a - M b = M $ a - b
  -- Delegate directly rather than using the class default (0 - x), which
  -- would allocate a replicated zero matrix first. This also matches the
  -- element-wise 'negate' of the vector instance.
  negate (M a) = M $ negate a
  abs (M a) = M $ abs a
  signum (M a) = M $ signum a
-- | Nested-vector representation of a matrix with dimensions @ns@.
--
-- With no dimensions it is a single 'I'-wrapped value. Each leading
-- dimension @n@ adds an outer 'VT' of length @n@ around the representation
-- of the remaining dimensions.
type family Matrix (ns :: [N]) :: * -> * where
  Matrix Ø = I
  Matrix (n :< ns) = VT n (Matrix ns)
-- | 'vgen' with the length taken from the type.
vgen_ :: Known Nat n => (Fin n -> f a) -> VT n f a
vgen_ f = vgen known f
-- | Build a vector of the given length. Each slot is produced by applying
-- the generator to that slot's position.
vgen :: Nat n -> (Fin n -> f a) -> VT n f a
vgen Z_ _ = ØV
vgen (S_ k) gen = gen FZ :* vgen k (\i -> gen (FS i))
-- | 'mgen' with the dimensions taken from the type.
mgen_ :: Known (Prod Nat) ns => (Prod Fin ns -> a) -> M ns a
mgen_ f = mgen known f
-- | Build a matrix with the given dimensions. Each entry is produced by
-- applying the generator to that entry's full coordinate tuple.
mgen :: Prod Nat ns -> (Prod Fin ns -> a) -> M ns a
mgen Ø gen = M (I (gen Ø))
mgen (n :< rest) gen = M (vgen n (getMatrix . mgen rest . curry' gen))
-- | Lift a function on the raw 'Matrix' representation to the 'M' wrapper.
onMatrix :: (Matrix ms a -> Matrix ns b) -> M ms a -> M ns b
onMatrix f (M raw) = M (f raw)
-- | The main diagonal of a square vector-of-vectors: row @i@ contributes
-- its element @i@.
diagonal :: VT n (VT n f) a -> VT n f a
diagonal rows = imap (\i row -> index i row) rows
-- | Swap the two outer vector layers. Output row @j@ collects element @j@
-- of every input row.
vtranspose :: Known Nat n => VT m (VT n f) a -> VT n (VT m f) a
vtranspose rows = vgen_ column
  where
    column j = vmap (index j) rows
-- | Swap the first two dimensions of a matrix.
transpose :: Known Nat n => M (m :< n :< ns) a -> M (n :< m :< ns) a
transpose (M mat) = M (vtranspose mat)
-- Sample matrices (presumably for trying out 'ppMatrix' interactively).
-- | A zero-dimensional matrix built from the literal 1 via 'fromInteger'.
m0 :: M Ø Int
m0 = 1
-- | A length-2 vector; 'fromInteger' on vectors replicates 2 into every slot.
m1 :: M '[N2] Int
m1 = 2
-- | A 2x4 matrix with every entry set to 3, again via 'fromInteger'.
m2 :: M '[N2,N4] Int
m2 = 3
-- | A 2x3x4 matrix whose entries are their own coordinates.
-- NOTE(review): 'fin' is assumed to convert a 'Fin' index to an Int.
m3 :: M '[N2,N3,N4] (Int,Int,Int)
m3 = mgen_ $ \(x :< y :> z) -> (fin x,fin y,fin z)
-- | A 2x3x4x5 matrix whose entries are their own coordinates.
m4 :: M '[N2,N3,N4,N5] (Int,Int,Int,Int)
m4 = mgen_ $ \(w :< x :< y :> z) -> (fin w,fin x,fin y,fin z)
-- | Render a vector in two steps. First render each element with the
-- element printer, then lay out the resulting vector of 'ShowS' pieces
-- with the layout printer.
ppVec :: (VT n ((->) String) String -> ShowS) -> (f a -> ShowS) -> VT n f a -> ShowS
ppVec layout element v = layout (vmap element v)
-- | Pretty-print a matrix to standard output, using the dimension count
-- known from the type.
ppMatrix :: forall ns a. (Show a, Known Length ns) => M ns a -> IO ()
ppMatrix m = putStrLn (ppMatrix' (known :: Length ns) (getMatrix m) "")
-- | Render a raw 'Matrix', given the number of dimensions it has.
--
-- A scalar is rendered with 'shows'. For each outer dimension the
-- sub-blocks are rendered recursively and then joined. When 'lEven' holds
-- for the remaining dimensions, blocks go side by side, separated by '|'
-- on every line. Otherwise they are stacked, separated by newlines. An
-- empty dimension renders as "[]".
ppMatrix' :: Show a => Length ns -> Matrix ns a -> ShowS
ppMatrix' LZ = shows . getI
ppMatrix' (LS l) = ppVec (vfoldMap' combine (showString "[]") id) (ppMatrix' l)
  where
    combine :: ShowS -> ShowS -> ShowS
    combine
      | lEven l = zipLines (\x y -> x . showChar '|' . y)
      | otherwise = \x y -> x . showChar '\n' . y
-- | Zip two lists with a function, padding the shorter list with 'mempty'
-- instead of truncating. The result is as long as the longer input.
mzipWith :: Monoid a => (a -> a -> b) -> [a] -> [a] -> [b]
mzipWith _ [] [] = []
mzipWith f (x : xs) [] = f x mempty : mzipWith f xs []
mzipWith f [] (y : ys) = f mempty y : mzipWith f [] ys
mzipWith f (x : xs) (y : ys) = f x y : mzipWith f xs ys
-- | Lay two rendered blocks side by side, line by line.
--
-- Both renderings are split into lines, and corresponding lines are merged
-- with @f@. A missing line on the shorter side becomes the identity
-- renderer. The merged lines are rejoined with newlines.
zipLines :: (ShowS -> ShowS -> ShowS) -> ShowS -> ShowS -> ShowS
zipLines f a b = compose (L.intersperse (showChar '\n') merged)
  where
    merged = mzipWith (\x y -> f (appEndo x) (appEndo y)) (asLines a) (asLines b)
    asLines s = map (Endo . showString) (lines (s ""))
-- | Compose every function in a container, leftmost outermost.
-- An empty container gives 'id'.
compose :: Foldable f => f (a -> a) -> a -> a
compose = foldr (.) id