{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Category.Tensor.Expr
(
MConcat,
Tensored (..),
type (++),
AppendTensored (..),
)
where
import Control.Category.Tensor
import Data.Function
import Data.Kind
import Prelude (Eq, Ord, Show)
-- | Right-fold a type-level list of types with a binary type constructor.
--
-- The empty list reduces to @unit@; a cons cell @y ': ys@ reduces to
-- @op y (MConcat op unit ys)@. For example,
-- @MConcat (,) () '[Int, Bool]@ reduces to @(Int, (Bool, ()))@.
type MConcat :: (Type -> Type -> Type) -> Type -> [Type] -> Type
type family MConcat op unit elems where
  MConcat op unit '[] = unit
  MConcat op unit (y ': ys) = y `op` MConcat op unit ys
-- | A value of the type-level fold @MConcat t i xs@, wrapped in a newtype.
--
-- 'MConcat' is a type family and therefore not injective: knowing the type
-- @MConcat t i xs@ does not let GHC recover @t@, @i@ or @xs@. Wrapping it
-- in a newtype keeps all three visible in the type so they can be inferred
-- or supplied via type application.
newtype Tensored t i xs = Tensored {getTensored :: MConcat t i xs}

-- Each instance below reuses the underlying @MConcat t i xs@ instance
-- unchanged (newtype deriving), and exists only when that instance does.
deriving newtype instance Show (MConcat t i xs) => Show (Tensored t i xs)

deriving newtype instance Eq (MConcat t i xs) => Eq (Tensored t i xs)

deriving newtype instance Ord (MConcat t i xs) => Ord (Tensored t i xs)
-- | Append two type-level lists, mirroring value-level list append.
--
-- Reduction recurses on the first argument only, so it proceeds whenever
-- the spine of the left list is known, e.g. @'[Int] ++ ys@ reduces to
-- @Int ': ys@ even when @ys@ is abstract.
type (++) :: [k] -> [k] -> [k]
type family (++) front back where
  '[] ++ back = back
  (hd ': tl) ++ back = hd ': (tl ++ back)
-- | Type-level lists @xs@ whose 'Tensored' values can be appended to another
-- 'Tensored' value.
--
-- 'appendTensored' takes the tensor product (under @t@) of a
-- @Tensored t i xs@ and a @Tensored t i ys@ and flattens it into a single
-- @Tensored t i (xs ++ ys)@.
class AppendTensored xs where
  appendTensored ::
    Tensor (->) t i =>
    t (Tensored t i xs) (Tensored t i ys) ->
    Tensored t i (xs ++ ys)
-- | Base case. The left operand @Tensored t i '[]@ wraps
-- @MConcat t i '[]@, which is just the unit @i@. After unwrapping it, the
-- forward direction of 'unitl' (@t i a -> a@) discards the unit and returns
-- the right operand unchanged, matching @'[] ++ ys = ys@.
instance AppendTensored '[] where
  appendTensored ::
    forall t i ys.
    Tensor (->) t i =>
    t (Tensored t i '[]) (Tensored t i ys) ->
    Tensored t i ('[] ++ ys)
  appendTensored = fwd unitl . glmap getTensored
-- | Inductive case. The composition below runs bottom to top:
--
-- 1. unwrap the left operand, exposing
--    @MConcat t i (x ': xs) = t x (MConcat t i xs)@;
-- 2. re-associate with the backward direction of 'assoc',
--    @t (t x m) r -> t x (t m r)@, so @x@ sits outermost;
-- 3. under @x@ (via 'grmap'), re-wrap @m@ as @Tensored t i xs@, append the
--    right operand recursively, and unwrap to @MConcat t i (xs ++ ys)@;
-- 4. wrap the result, since @t x (MConcat t i (xs ++ ys))@ is exactly
--    @MConcat t i ((x ': xs) ++ ys)@.
instance AppendTensored xs => AppendTensored (x ': xs) where
  appendTensored ::
    forall t i ys.
    Tensor (->) t i =>
    t (Tensored t i (x ': xs)) (Tensored t i ys) ->
    Tensored t i ((x ': xs) ++ ys)
  appendTensored =
    Tensored
      -- @xs@ must be given explicitly: it only appears under the
      -- non-injective families 'MConcat' and '++', so GHC cannot infer it.
      . grmap (getTensored . appendTensored @xs . glmap Tensored)
      . bwd assoc
      . glmap getTensored