{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
module Data.Pattern.Base.Tuple (
Fun,
Tuple,
zeroT,
oneT,
(<+>),
runTuple,
Map, Distribute(..)
) where
import Data.Pattern.Base.Difference
import Data.Pattern.Base.TypeList
import Data.Kind (Type)
-- | The curried function type that takes one argument per element of the
-- type-level list, in list order, and yields @res@.
--
-- For example, @Fun '[Int, Bool] res@ reduces to @Int -> Bool -> res@, and
-- @Fun '[] res@ reduces to just @res@.
type family Fun (args :: [Type]) res
type instance Fun '[] res = res
type instance Fun (a ': rest) res = a -> Fun rest res
-- | A packed heterogeneous tuple indexed by the list of its component types.
-- It is 'Unit' for the empty list; otherwise it is a 'Pair' of the head
-- component and the packed tail.
data family Tup (ts :: [Type])
data instance Tup '[] = Unit
data instance Tup (a ': rest) = Pair a (Tup rest)
-- | Type-level lists of argument types for which a function consuming a
-- packed 'Tup' can be turned into the equivalent curried function ('Fun').
class Uncurriable xs where
  -- | Convert a consumer of a whole 'Tup' into a curried function that takes
  -- the components one at a time, in list order.
  uncurryT :: (Tup xs -> r) -> Fun xs r

-- With no components left, the consumer is applied to the empty tuple.
instance Uncurriable '[] where
  uncurryT k = k Unit

-- Accept the head component, then let the tail instance collect the rest,
-- re-pairing the head with the collected tail before calling @k@.
instance Uncurriable t => Uncurriable (h ': t) where
  uncurryT k x = uncurryT (k . Pair x)
-- | Continuation-passing form of a tuple with component types @ys@. Given a
-- curried consumer of every component (a 'Fun'), it applies the consumer to
-- the components and returns whatever the consumer produces.
newtype Tuple' ys = Tuple' { runTuple' :: forall res. Fun ys res -> res }
-- | A heterogeneous tuple with component types @ys@.
--
-- It is stored as a 'D' (from "Data.Pattern.Base.Difference") over 'Tuple'',
-- so '<+>' can combine tuples with 'plusD'. The constructor is not exported:
-- outside code builds values with 'zeroT', 'oneT' and '<+>', and consumes
-- them with 'runTuple'.
newtype Tuple ys = Tuple (D Tuple' ys)
-- | The empty tuple, with no components. It wraps 'zeroD' from
-- "Data.Pattern.Base.Difference".
zeroT :: Tuple '[]
zeroT = Tuple zeroD
-- | A tuple holding exactly one component, @a@.
--
-- The function passed to 'mkOneD' takes a continuation-style tuple @t@ for
-- some further components and builds a new consumer-runner. That runner
-- first applies the consumer @k@ to @a@, then hands the partially applied
-- function on to @t@. So @a@ is supplied before the components of @t@.
-- NOTE(review): how 'mkOneD' turns this prepend step into a 'D' is defined
-- in "Data.Pattern.Base.Difference" and not visible here.
oneT :: a -> Tuple '[a]
oneT a = Tuple (mkOneD (\(Tuple' t) -> Tuple' (\k -> t (k a))))
-- | Type-level lists for which a packed 'Tup' can be converted into the
-- continuation-based 'Tuple' representation.
class Tupable xs where
  -- | Rebuild a 'Tuple' from a packed 'Tup', component by component.
  mkTuple :: Tup xs -> Tuple xs

-- The empty packed tuple becomes the empty 'Tuple'.
instance Tupable '[] where
  mkTuple u = case u of
    Unit -> zeroT

-- A singleton tuple for the head, followed by the converted tail.
instance Tupable t => Tupable (h ': t) where
  mkTuple packed = case packed of
    Pair x rest -> oneT x <+> mkTuple rest
-- | Concatenate two tuples. The components of the first come before those of
-- the second, as the ':++:' in the result type records.
(<+>) :: Tuple xs -> Tuple ys -> Tuple (xs :++: ys)
(<+>) (Tuple front) (Tuple back) = Tuple (plusD front back)
-- | Feed the components of a tuple, in the order of @xs@, to a curried
-- function and return its result.
runTuple :: Tuple xs -> Fun xs r -> r
runTuple (Tuple diff) = runTuple' completed
  where
    -- Evaluate the 'D' starting from @Tuple' id@: the empty tuple, whose
    -- consumer (@Fun '[] r = r@) already is the final result.
    completed = evalD (Tuple' id) diff
-- | Run a tuple against a consumer that receives all components packed into
-- a single 'Tup'. The consumer is first curried with 'uncurryT'.
runTupleT :: Uncurriable xs => Tuple xs -> (Tup xs -> r) -> r
runTupleT tup = runTuple tup . uncurryT
-- | Split a non-empty 'Tuple' into its first component and a 'Tuple' of the
-- remaining components.
--
-- The tuple is run once, via 'runTupleT', against a consumer that receives
-- all components packed as a 'Tup'. The tail is then rebuilt with 'mkTuple'.
unconsTuple :: (Uncurriable t, Tupable t) => Tuple (h ': t) -> (h, Tuple t)
-- The lambda binds fresh names (x, rest) so it does not shadow the argument
-- (previously both were called t, which -Wname-shadowing flags).
unconsTuple tup = runTupleT tup (\(Pair x rest) -> (x, mkTuple rest))
-- | The first component of a non-empty 'Tuple'.
tupleHead :: (Uncurriable t, Tupable t) => Tuple (h ': t) -> h
tupleHead tup = case unconsTuple tup of
  (hd, _) -> hd
-- | All components of a non-empty 'Tuple' except the first.
tupleTail :: (Uncurriable t, Tupable t) => Tuple (h ': t) -> Tuple t
tupleTail tup = case unconsTuple tup of
  (_, rest) -> rest
-- | Apply a type constructor to every element of a type-level list.
--
-- For example, @Map Maybe '[Int, Bool]@ reduces to
-- @'[Maybe Int, Maybe Bool]@.
type family Map (g :: Type -> Type) (elems :: [Type]) :: [Type]
type instance Map g '[] = '[]
type instance Map g (a ': rest) = g a ': Map g rest
-- | Type-level lists whose 'Tuple's can be pushed out of a 'Functor'.
class Distribute xs where
  -- | Turn a functor of tuples into a tuple of functors. Each component of
  -- the result is the input with component projections mapped over it.
  distribute :: Functor f => f (Tuple xs) -> Tuple (Map f xs)

-- No components to project, so the input is ignored.
instance Distribute '[] where
  distribute = const zeroT

-- The first component projects the head out of every element. The remaining
-- components come from recursing on the functor of tails.
instance (Uncurriable t, Tupable t, Distribute t) => Distribute (h ': t) where
  distribute wrapped = oneT heads <+> distribute rests
    where
      heads = fmap tupleHead wrapped
      rests = fmap tupleTail wrapped