module Naperian where
import qualified Prelude
import Prelude hiding ( lookup, length, replicate, zipWith )
import qualified Data.IntMap as IntMap
import Data.List ( intercalate )
import Data.Kind ( Type, Constraint )
import Control.Applicative ( liftA2 )
import qualified GHC.Exts as L (IsList(..))
import GHC.Prim
import GHC.TypeLits
import qualified Data.Vector as Vector
import Data.Foldable ( toList )
-- | An 'Int' tagged with a type-level natural @n@.
--
-- The 'Fin' constructor performs no range check of its own; 'finite' is the
-- checked way to build a value.
data Finite (n :: Nat) = Fin Int
  deriving (Eq, Show)
-- | Checked constructor for 'Finite'.
--
-- Returns @Just (Fin x)@ only when @0 <= x < n@, and 'Nothing' otherwise.
--
-- The valid range is half-open because 'viota' enumerates the positions of a
-- @Vector n@ as @0 .. n - 1@ and 'index' uses the wrapped 'Int' directly as an
-- offset.  Accepting @x == n@ or a negative @x@ would produce an index that is
-- out of bounds for every @Vector n@.
finite :: forall n. KnownNat n => Int -> Maybe (Finite n)
finite x
  | 0 <= x && x < bound = Just (Fin x)
  | otherwise           = Nothing
  where
    -- Runtime value of the type-level bound @n@.
    bound :: Int
    bound = fromIntegral (natVal' (proxy# :: Proxy# n))
-- | Combine two applicative structures with a binary function.
--
-- The function is injected with 'pure' and then applied to each structure in
-- turn with '<*>'.
azipWith :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
azipWith h xs ys = partial <*> ys
  where
    -- The lifted function already applied to the first structure.
    partial = pure h <*> xs
-- | Render already-shown elements as @\<a,b,c\>@: comma-separated and wrapped
-- in angle brackets.
showVector :: [String] -> String
showVector xs = concat ["<", intercalate "," xs, ">"]
-- | A container holding exactly two values of the same type.
data Pair a = Pair a a
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | Position-wise applicative: 'pure' duplicates its argument into both slots
-- and '<*>' applies the left function to the left value and the right function
-- to the right value.
instance Applicative Pair where
  pure x = Pair x x
  Pair f g <*> Pair x y = Pair (f x) (g y)
-- | A wrapper around "Data.Vector"'s vector whose type carries a type-level
-- length @n@.
--
-- The constructor itself does not check that the wrapped vector has @n@
-- elements; 'fromList' is the length-checked way to build one.
newtype Vector (n :: Nat) a = Vector (Vector.Vector a)
  deriving (Eq, Ord, Functor, Foldable, Traversable)
-- | Shows each element and renders them with 'showVector'.
instance Show a => Show (Vector n a) where
  show v = showVector (map show (toList v))
-- | Position-wise applicative: 'pure' fills all @n@ slots with the same value
-- (via 'replicate') and '<*>' applies functions to arguments slot by slot (via
-- 'zipWith').
instance KnownNat n => Applicative (Vector n) where
  pure x = replicate x
  fs <*> xs = zipWith (\f x -> f x) fs xs
-- | List-literal support for 'Vector'.
--
-- The instance method @fromList@ delegates to the length-checked top-level
-- 'fromList' and calls 'error' when the list length does not match @n@.
instance (KnownNat n, Traversable (Vector n)) => L.IsList (Vector n a) where
  type Item (Vector n a) = a
  toList = Data.Foldable.toList
  fromList xs =
    maybe
      (error "Demanded vector of a list that wasn't the proper length")
      id
      (fromList xs)
-- | Drop the first element; the type-level length goes from @n + 1@ to @n@.
tail :: Vector (n + 1) a -> Vector n a
tail (Vector xs) = Vector $ Vector.tail xs
-- | Build a 'Vector' from a list.
--
-- Returns 'Nothing' unless the list has exactly @n@ elements.
fromList :: forall n a. KnownNat n => [a] -> Maybe (Vector n a)
fromList xs
  | Prelude.length xs == expected = Just (Vector (Vector.fromList xs))
  | otherwise                     = Nothing
  where
    -- Runtime value of the type-level length @n@.
    expected :: Int
    expected = fromIntegral (natVal' (proxy# :: Proxy# n))
-- | Combine two vectors of the same type-level length element by element.
zipWith :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
zipWith f (Vector xs) (Vector ys) = Vector $ Vector.zipWith f xs ys
-- | The type-level length @n@ as an 'Int'.
--
-- The vector argument is ignored; the result comes from the type alone.
length :: forall n a. KnownNat n => Vector n a -> Int
length _ = fromIntegral len
  where
    len = natVal' (proxy# :: Proxy# n)
-- | A vector with the given value in each of its @n@ slots.
replicate :: forall n a. KnownNat n => a -> Vector n a
replicate x = Vector (Vector.replicate count x)
  where
    -- Runtime value of the type-level length @n@.
    count :: Int
    count = fromIntegral (natVal' (proxy# :: Proxy# n))
-- | Look up the element at a 'Finite' position, using "Data.Vector"'s @(!)@
-- on the wrapped 'Int'.
index :: Vector n a -> Finite n -> a
index (Vector xs) (Fin i) = xs Vector.! i
-- | The vector of its own positions: @Fin 0, Fin 1, .., Fin (n - 1)@.
viota :: forall n. KnownNat n => Vector n (Finite n)
viota = Vector (Fin <$> Vector.enumFromN 0 count)
  where
    -- Runtime value of the type-level length @n@.
    count :: Int
    count = fromIntegral (natVal' (proxy# :: Proxy# n))
-- | Functors whose elements can be addressed by a position type @'Log' f@.
--
-- * 'lookup' reads the element at a position.
-- * 'tabulate' builds a structure from a function on positions.
-- * 'positions' is the structure holding each position at that position.
--
-- 'tabulate' and 'positions' are defined in terms of each other, so an
-- instance must supply at least one of them; otherwise both defaults would
-- call each other forever.  The @MINIMAL@ pragma makes GHC warn about such an
-- instance instead of letting it loop at runtime.
class Functor f => Naperian f where
  {-# MINIMAL lookup, (tabulate | positions) #-}

  -- | The type of positions inside @f@.
  type Log f

  -- | Read the element stored at the given position.
  lookup :: f a -> (Log f -> a)

  -- | Build a structure by evaluating a function at every position.
  tabulate :: (Log f -> a) -> f a
  tabulate h = fmap h positions

  -- | The structure that holds, at each position, that position itself.
  positions :: f (Log f)
  positions = tabulate id
-- | Swap the nesting order of two 'Naperian' functors.
--
-- The element at outer position @i@ and inner position @j@ of the input ends
-- up at outer position @j@ and inner position @i@ of the result.
transpose :: (Naperian f, Naperian g) => f (g a) -> g (f a)
transpose xss =
  tabulate (\j -> tabulate (\i -> lookup (lookup xss i) j))
-- | Positions in a 'Pair' are 'Bool's: 'False' addresses the first component
-- and 'True' the second.
instance Naperian Pair where
  type Log Pair = Bool
  lookup (Pair x _) False = x
  lookup (Pair _ y) True  = y
  positions = Pair False True
-- | Positions in a @Vector n@ are @'Finite' n@ values; lookup is 'index' and
-- the position table is 'viota'.
instance KnownNat n => Naperian (Vector n) where
  type Log (Vector n) = Finite n
  lookup v i = index v i
  positions = viota
-- | Functors usable as one dimension of a 'Hyper': applicative, 'Naperian'
-- and traversable.
class (Applicative f, Naperian f, Traversable f) => Dimension f where
  -- | Number of elements held.  The default counts the elements produced by
  -- 'toList'.
  size :: f a -> Int
  size xs = Prelude.length (toList xs)
-- | A 'Pair' always has size 2; the argument is not inspected.
instance Dimension Pair where
  size _ = 2
-- | The size of a @Vector n@ is its type-level length, obtained via 'length'.
instance KnownNat n => Dimension (Vector n) where
  size v = length v
-- | Inner product: multiply the two structures through their 'Applicative'
-- instance and sum the results.
inner :: (Num a, Dimension f) => f a -> f a -> a
inner xs ys = sum products
  where
    products = liftA2 (*) xs ys
-- | Matrix product of an @f@-by-@g@ matrix with a @g@-by-@h@ matrix.
--
-- Each entry of the result is the 'inner' product of a row of the first
-- argument with a column of the second (a row of its 'transpose').
matrix :: ( Num a
          , Dimension f
          , Dimension g
          , Dimension h
          ) => f (g a)
            -> g (h a)
            -> f (h a)
matrix xss yss = liftA2 (liftA2 inner) rows cols
  where
    -- Every row of the left matrix, replicated across the @h@ dimension.
    rows = fmap pure xss
    -- The columns of the right matrix, replicated across the @f@ dimension.
    cols = pure (transpose yss)
-- | A value with a type-level list of dimensions.
--
-- * 'Scalar' wraps a single value and has no dimensions.
-- * 'Prism' adds the dimension @f@ at the head of the list: the elements of
--   the wrapped @'Hyper' fs@ are themselves @f@-structures of @a@.
data Hyper :: [Type -> Type] -> Type -> Type where
  Scalar :: a -> Hyper '[] a
  Prism :: (Dimension f, Shapely fs) => Hyper fs (f a) -> Hyper (f : fs) a
-- | Extract the value from a dimensionless 'Hyper'.
point :: Hyper '[] a -> a
point h = case h of
  Scalar x -> x
-- | Remove the outer 'Prism', exposing the head dimension @f@ as the element
-- type.
crystal :: Hyper (f : fs) a -> Hyper fs (f a)
crystal h = case h of
  Prism x -> x
-- | Shows every element, then peels the 'Prism' layers one at a time,
-- rendering each @f@-structure of strings with 'showVector'.
instance Show a => Show (Hyper fs a) where
  show = render . fmap show
    where
      render :: Hyper gs String -> String
      render (Scalar s)   = s
      render (Prism rest) = render (fmap (showVector . toList) rest)
-- | Dimension lists for which a 'Hyper' can be built from a single value and
-- measured.
class Shapely fs where
  -- | Build a 'Hyper' of shape @fs@ from one value.
  hreplicate :: a -> Hyper fs a
  -- | An 'Int' size for a 'Hyper' of shape @fs@.
  hsize :: Hyper fs a -> Int
-- | The empty shape: replication wraps the value in 'Scalar' and the size is
-- always 1.
instance Shapely '[] where
  hreplicate = Scalar
  hsize _ = 1
-- | A non-empty shape: replication fills the head dimension with 'pure' and
-- recurses on the rest; the size is the 'size' of the head dimension times the
-- 'hsize' of the rest.
instance (Dimension f, Shapely fs) => Shapely (f : fs) where
  hreplicate = Prism . hreplicate . pure
  hsize (Prism rest) = size (first rest) * hsize rest
-- | Maps over every element, descending through each 'Prism' layer with an
-- extra 'fmap'.
instance Functor (Hyper fs) where
  fmap f h = case h of
    Scalar a   -> Scalar (f a)
    Prism rest -> Prism (fmap (fmap f) rest)
-- | Position-wise applicative: 'pure' is 'hreplicate' and '<*>' applies
-- functions to arguments through 'hzipWith'.
instance Shapely fs => Applicative (Hyper fs) where
  pure x = hreplicate x
  hf <*> hx = hzipWith (\f x -> f x) hf hx
-- | Combine two 'Hyper's of the same shape element by element.
--
-- At each 'Prism' layer the function is lifted through that layer's dimension
-- with 'azipWith'.
hzipWith :: (a -> b -> c) -> Hyper fs a -> Hyper fs b -> Hyper fs c
hzipWith f (Scalar x) (Scalar y) = Scalar (f x y)
hzipWith f (Prism xs) (Prism ys) = Prism (hzipWith lifted xs ys)
  where
    lifted = azipWith f
-- | The first element of a 'Hyper': at every 'Prism' layer, take the first
-- element (in 'toList' order) of the first inner structure.
--
-- Calls 'error' if a dimension on that path has no elements.  The result type
-- is a bare @a@, so there is no total answer in that case; the explicit match
-- replaces the partial 'head' so the failure names this function.
first :: Shapely fs => Hyper fs a -> a
first (Scalar a) = a
first (Prism x) =
  case toList (first x) of
    a : _ -> a
    []    -> error "Naperian.first: dimension has no elements"
-- | Swap the two head dimensions of a 'Hyper' by applying 'transpose' to
-- every element beneath the two outer 'Prism' layers.
transposeH :: Hyper (f : (g : fs)) a
           -> Hyper (g : (f : fs)) a
transposeH (Prism (Prism rest)) = Prism . Prism $ fmap transpose rest
-- | Collapse the head dimension of a 'Hyper' with a right fold, leaving a
-- 'Hyper' over the remaining dimensions.
foldrH :: (a -> a -> a) -> a -> Hyper (f : fs) a -> Hyper fs a
foldrH f z (Prism rest) = foldr f z <$> rest
-- | Lift a one-argument function to 'Hyper's; this is 'fmap' at a restricted
-- type.
unary :: Shapely fs => (a -> b) -> (Hyper fs a -> Hyper fs b)
unary f = fmap f
-- | Lift a two-argument function to 'Hyper's of compatible shapes.
--
-- Both arguments are brought to the common shape @hs@ (the 'Max' of the two
-- shapes) with 'align' and then combined element by element with 'hzipWith'.
binary :: ( Compatible fs gs
          , Max fs gs ~ hs
          , Alignable fs hs
          , Alignable gs hs
          ) => (a -> b -> c)
            -> Hyper fs a
            -> Hyper gs b
            -> Hyper hs c
binary f x y = hzipWith f alignedX alignedY
  where
    alignedX = align x
    alignedY = align y
-- | Add a new head dimension @f@ by replacing each element with its 'pure'
-- replication in @f@.
up :: (Shapely fs, Dimension f) => Hyper fs a -> Hyper (f : fs) a
up h = Prism (fmap pure h)
-- | Inner product on 'Hyper's: multiply the two arguments element by element
-- with 'binary', then sum away the head dimension with 'foldrH'.
innerH :: ( Max fs gs ~ (f : hs)
          , Alignable fs (f : hs)
          , Alignable gs (f : hs)
          , Compatible fs gs
          , Num a
          ) => Hyper fs a
            -> Hyper gs a
            -> Hyper hs a
innerH xs ys = foldrH (+) 0 $ binary (*) xs ys
-- | Matrix product on rank-2 'Hyper's.
--
-- The head dimension of the first argument is exposed with 'crystal' and the
-- second argument is transposed.  Both are then lifted to the result shape
-- with 'up' and combined element by element with 'inner'.
matrixH :: ( Num a
           , Dimension f
           , Dimension g
           , Dimension h
           ) => Hyper '[ g, f ] a
             -> Hyper '[ h, g ] a
             -> Hyper '[ h, f ] a
matrixH x y =
  case transposeH y of
    Prism (Prism ys) -> hzipWith inner (up (crystal x)) (Prism (up ys))
-- | Shapes @fs@ that can be converted to the shape @gs@.
class (Shapely fs, Shapely gs) => Alignable fs gs where
  -- | Convert a 'Hyper' of shape @fs@ into one of shape @gs@.
  align :: Hyper fs a -> Hyper gs a
-- | Two empty shapes are already aligned.
instance Alignable '[] '[] where
  align h = h

-- | Shapes with the same head dimension align by aligning their tails.
instance (Dimension f, Alignable fs gs) => Alignable (f : fs) (f : gs) where
  align (Prism rest) = Prism $ align rest

-- | A scalar aligns with any non-empty shape by replicating its value.
instance (Dimension f, Shapely fs) => Alignable '[] (f : fs) where
  align h = case h of
    Scalar a -> hreplicate a
-- | The larger of two shapes.
--
-- If either list is empty the result is the other list.  When both are
-- non-empty and share the same head, the head is kept and the tails are
-- compared.  There is no equation for lists whose heads differ.
type family Max (fs :: [Type -> Type]) (gs :: [Type -> Type]) :: [Type -> Type] where
  Max '[] '[] = '[]
  Max '[] (f : gs) = f : gs
  Max (f : fs) '[] = f : fs
  Max (f : fs) (f : gs) = f : Max fs gs
-- | Constraint that two shapes can be aligned.
--
-- The constraint holds trivially when either list is empty, and recurses on
-- the tails when both lists share the same head.  Any other combination falls
-- through to the final equation, which raises a custom type error naming the
-- two shapes.  Equation order matters: the catch-all must stay last.
type family Compatible (fs :: [Type -> Type]) (gs :: [Type -> Type]) :: Constraint where
  Compatible '[] '[] = ()
  Compatible '[] (f : gs) = ()
  Compatible (f : fs) '[] = ()
  Compatible (f : fs) (f : gs) = Compatible fs gs
  Compatible a b = TypeError (
    'Text "Mismatched dimensions!"
    ':$$: 'Text "The dimension " ':<>: 'ShowType a ':<>: 'Text " can't be aligned with"
    ':$$: 'Text "the dimension " ':<>: 'ShowType b)
-- | All elements of a 'Hyper' as a flat list.
--
-- For a 'Prism', the elements of the inner 'Hyper' are listed first and each
-- one (an @f@-structure) is expanded in place with 'toList'.
elements :: Shapely fs => Hyper fs a -> [a]
elements h = case h of
  Scalar a   -> [a]
  Prism rest -> concatMap toList (elements rest)
-- | The elements of a shape-@fs@ value stored in a single "Data.Vector"
-- vector.  The shape is tracked only in the type.
data Flat fs a where
  Flat :: Shapely fs => Vector.Vector a -> Flat fs a
-- | Maps over the underlying vector, keeping the shape tag.
instance Functor (Flat fs) where
  fmap f (Flat xs) = Flat (f <$> xs)
-- | Shows every element and renders the flat vector with 'showVector'.
instance Show a => Show (Flat fs a) where
  show = render . fmap show
    where
      render :: Flat gs String -> String
      render (Flat strs) = showVector (toList strs)
-- | Store the 'elements' of a 'Hyper' in a 'Flat' vector of the same shape.
flatten :: Shapely fs => Hyper fs a -> Flat fs a
flatten = Flat . Vector.fromList . elements
-- | A sparse representation: a single value plus an 'IntMap.IntMap' of
-- explicitly stored values.  'unsparse' uses the single value as the fill for
-- every index that is absent from the map.
data Sparse fs a where
  Sparse :: Shapely fs => a -> IntMap.IntMap a -> Sparse fs a
-- | Expand a 'Sparse' value into a dense 'Flat' one.
--
-- A vector of 'hsize' elements is filled with the default value, then every
-- entry of the map overwrites the slot at its key.
--
-- The update uses the bounds-checked @(//)@ rather than an unchecked write.
-- A key outside @[0, hsize)@ therefore raises an index error instead of
-- writing outside the vector.
--
-- NOTE(review): keys are presumably flat indices in the same order as
-- 'elements' / 'flatten' — confirm against the producers of 'Sparse' values.
unsparse :: forall fs a. Shapely fs => Sparse fs a -> Flat fs a
unsparse (Sparse def entries) = Flat (filled Vector.// IntMap.assocs entries)
  where
    -- Dense vector holding the default value at every index.
    filled = Vector.replicate total def
    -- Element count for shape @fs@, measured on a throwaway unit-filled
    -- 'Hyper'.
    total = hsize (hreplicate () :: Hyper fs ())
-- | A matrix as a length-@n@ 'Vector' of length-@m@ 'Vector's.
type Matrix n m v = Vector n (Vector m v)
-- | 'inner' product of two length-3 vectors written as list literals.
example1 :: Int
example1 = inner v1 v2
  where
    v1, v2 :: Vector 3 Int
    v1 = [ 1, 2, 3 ]
    v2 = [ 4, 5, 6 ]
-- | 'matrix' product of a 2-by-3 matrix with a 3-by-2 matrix.
example2 :: Matrix 2 2 Int
example2 = matrix m1 m2
  where
    m1 :: Matrix 2 3 Int
    m1 = [ [ 1, 2, 3 ]
         , [ 4, 5, 6 ]
         ]

    m2 :: Matrix 3 2 Int
    m2 = [ [ 9, 8 ]
         , [ 6, 5 ]
         , [ 3, 2 ]
         ]
-- | 'innerH' on two one-dimensional 'Hyper's holding the same data as
-- 'example1'.
example3 :: Hyper '[] Int
example3 = innerH v1 v2
  where
    v1, v2 :: Hyper '[Vector 3] Int
    v1 = Prism (Scalar [1, 2, 3])
    v2 = Prism (Scalar [4, 5, 6])
-- | 'matrixH' on two rank-2 'Hyper's wrapping the same matrices as
-- 'example2'.
example4 :: Hyper '[Vector 2, Vector 2] Int
example4 = matrixH v1 v2
  where
    x :: Matrix 2 3 Int
    x = [ [ 1, 2, 3 ]
        , [ 4, 5, 6 ]
        ]

    y :: Matrix 3 2 Int
    y = [ [ 9, 8 ]
        , [ 6, 5 ]
        , [ 3, 2 ]
        ]

    v1 :: Hyper '[Vector 3, Vector 2] Int
    v1 = Prism (Prism (Scalar x))

    v2 :: Hyper '[Vector 2, Vector 3] Int
    v2 = Prism (Prism (Scalar y))