{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Control.Category.Tensor.Expr
  ( -- * Type Families
    MConcat,
    Tensored (..),
    type (++),

    -- * AppendTensored
    AppendTensored (..),
  )
where

import Control.Category.Tensor
import Data.Function
import Data.Kind
import Prelude (Eq, Ord, Show)

--------------------------------------------------------------------------------

-- | Right-nested fold of a type-level list with a binary type constructor.
--
-- @MConcat t i '[x1, x2, ..., xn]@ reduces to @t x1 (t x2 (... (t xn i)))@:
-- each element is paired (via @t@) with the fold of the remaining elements,
-- and the empty list reduces to the unit @i@. The name mirrors @mconcat@,
-- reading @t@ as the product and @i@ as its unit.
--
-- The family is not injective, so the examples below go through the
-- 'Tensored' wrapper, which keeps the element list visible in the type.
--
-- __Examples:__
--
-- >>> :{
-- let foo :: Tensored (,) () '[Bool, Int]
--     foo = Tensored (True, (8, ()))
-- :}
--
-- >>> :{
-- let bar :: Tensored Either Void '[Bool, Int]
--     bar = Tensored $ Right $ Left 8
-- :}
--
-- >>> :{
-- let baz :: Tensored These Void '[Bool, Int]
--     baz = Tensored $ These True $ This 8
-- :}
type MConcat :: (Type -> Type -> Type) -> Type -> [Type] -> Type
type family MConcat t i xs where
  -- An empty list collapses to the unit; the product is irrelevant here.
  MConcat _ i '[] = i
  -- Pair the head with the fold of the tail, nesting to the right.
  MConcat t i (x ': rest) = t x (MConcat t i rest)

-- | A value of the nested type @'MConcat' t i xs@, wrapped so that the element
-- list @xs@ stays visible in the type.
--
-- 'MConcat' is a non-injective type family, so a bare @MConcat t i xs@ does
-- not let GHC recover @xs@. @Tensored t i xs@ mentions @xs@ as an ordinary
-- argument, so @xs@ can be inferred from the wrapper's type.
type Tensored :: (Type -> Type -> Type) -> Type -> [Type] -> Type
newtype Tensored t i xs = Tensored
  { -- | Unwrap to the underlying right-nested @t@-structure.
    getTensored :: MConcat t i xs
  }

-- Equality, ordering and 'Show' are all inherited from the wrapped 'MConcat'
-- value through newtype deriving. Each instance exists exactly when the
-- underlying nested type has one, and @show@ prints that nested value with no
-- @Tensored@ constructor around it.
deriving newtype instance (Eq (MConcat f e ts)) => Eq (Tensored f e ts)

deriving newtype instance (Ord (MConcat f e ts)) => Ord (Tensored f e ts)

deriving newtype instance (Show (MConcat f e ts)) => Show (Tensored f e ts)

--------------------------------------------------------------------------------

-- | Append two type-level lists.
--
-- Both equations match only on the first argument, so @front ++ back@
-- reduces as soon as the spine of @front@ is known, whatever @back@ is.
type (++) :: [k] -> [k] -> [k]
type family front ++ back where
  '[] ++ back = back
  (x ': front) ++ back = x ': (front ++ back)

-- | Concatenation of 'Tensored' values indexed by type-level lists.
--
-- The class is indexed by the element list @xs@ of the left operand. That
-- list is the method's first type argument, so it can be fixed with
-- @appendTensored \@xs@.
type AppendTensored :: [Type] -> Constraint
class AppendTensored xs where
  -- | Merge a @t@-pair of 'Tensored' values into one 'Tensored' value whose
  -- element list is @xs ++ ys@.
  appendTensored ::
    Tensor (->) t i =>
    t (Tensored t i xs) (Tensored t i ys) ->
    Tensored t i (xs ++ ys)

-- | Base case: the left operand holds no elements.
--
-- At this instance the method has type
--
-- @
-- Tensor (->) t i => t (Tensored t i '[]) (Tensored t i ys) -> Tensored t i ys
-- @
--
-- since @'[] ++ ys@ reduces to @ys@. Unwrapping the left operand yields the
-- unit @i@ (@'MConcat' t i '[]@), which 'unitl' (an isomorphism
-- @t i a ≅ a@) then removes.
instance AppendTensored '[] where
  appendTensored = fwd unitl . glmap getTensored

-- | Inductive case: split off the head @x@ of the left operand and recurse on
-- @xs@.
--
-- Reading the pipeline right to left, starting from
-- @t (Tensored t i (x ': xs)) (Tensored t i ys)@:
--
-- 1. @glmap getTensored@ unwraps the left operand to @t x (MConcat t i xs)@.
--
-- 2. @bwd assoc@ reassociates @t (t x a) b@ into @t x (t a b)@.
--
-- 3. Under @grmap@, the inner pair is re-wrapped with @glmap Tensored@,
--    appended by the @xs@ instance, and unwrapped again, giving
--    @MConcat t i (xs ++ ys)@.
--
-- 4. 'Tensored' wraps the resulting @t x (MConcat t i (xs ++ ys))@, which is
--    exactly @MConcat t i ((x ': xs) ++ ys)@.
--
-- The explicit @\@xs@ is required. 'Tensored' applied to a bare
-- @MConcat t i xs@ cannot determine @xs@ because 'MConcat' is not injective.
-- With ScopedTypeVariables, @xs@ here refers to the instance head's variable.
instance AppendTensored xs => AppendTensored (x ': xs) where
  appendTensored =
    Tensored
      . grmap (getTensored . appendTensored @xs . glmap Tensored)
      . bwd assoc
      . glmap getTensored