{-# LANGUAGE CPP #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
module Data.Semimodule.Transform where
import safe Control.Category (Category, (>>>))
import safe Data.Functor.Compose
import safe Data.Functor.Product
import safe Data.Functor.Rep
import safe Data.Profunctor
import safe Data.Semimodule
import safe Data.Tuple (swap)
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)
import safe Test.Logic
import safe qualified Control.Category as C
import safe qualified Data.Bifunctor as B
-- | A transform that works uniformly for every carrier type @a@.
--
-- Because @a@ is universally quantified, an 'Index' cannot inspect or combine
-- carrier values; it can only relate the index type @b@ to the index type @c@.
type Index b c = forall a . Tran a b c
-- | A transform between functions into a carrier type @a@.
--
-- 'runTran' takes a function indexed by @c@ and returns a function indexed
-- by @b@. The derived 'Functor' instance maps over the last parameter @c@.
newtype Tran a b c = Tran { runTran :: (c -> a) -> (b -> a) } deriving Functor
-- | Apply a transform to a container.
--
-- The argument is turned into a function of its index with 'index', that
-- function is pushed through the transform, and the outcome is rebuilt into
-- a container with 'tabulate'.
app :: Basis b f => Basis c g => Tran a b c -> g a -> f a
app (Tran t) = tabulate . t . index
-- | Transforms form a category.
--
-- A 'Tran' acts on index functions "backwards" (from @c@-indexed to
-- @b@-indexed), so composing two transforms composes the wrapped functions in
-- the opposite order.
--
-- The method bindings must be indented relative to @instance@; at column 0 the
-- layout rule would close the instance body and treat them as top-level
-- declarations.
instance Category (Tran a) where
  id = Tran id
  Tran f . Tran g = Tran (\k -> g (f k))
-- | 'Tran' is a profunctor in its two index parameters.
--
-- 'lmap' pre-composes the produced @b@-indexed function with the given
-- function; 'rmap' reuses the derived 'Functor' instance on the last parameter.
--
-- The method bindings must be indented relative to @instance@; at column 0 the
-- layout rule would close the instance body and treat them as top-level
-- declarations.
instance Profunctor (Tran a) where
  lmap f (Tran t) = Tran $ \ca -> t ca . f
  rmap = fmap
-- | Change the carrier type of a transform along a pair of conversions.
--
-- The second function converts the incoming index function back to the old
-- carrier before the transform runs; the first converts the transform's
-- output to the new carrier.
invmap :: (a1 -> a2) -> (a2 -> a1) -> Tran a1 b c -> Tran a2 b c
invmap to from (Tran t) = Tran $ \k -> to . t (from . k)
-- | A transform whose source and target index types coincide.
type Endo a b = Tran a b b
-- | Build a nested container from a single inner container.
--
-- The result is obtained by applying 'in1' (which looks only at the second
-- component of a pair index) and unwrapping the 'Compose' layer, so the entry
-- at outer\/inner position @(i, j)@ is the input's entry at @j@.
rows :: Free f => Free g => g a -> f (g a)
rows v = getCompose (app in1 v)
{-# INLINE rows #-}
-- | Build a nested container from a single outer container.
--
-- The result is obtained by applying 'in2' (which looks only at the first
-- component of a pair index) and unwrapping the 'Compose' layer, so the entry
-- at outer\/inner position @(i, j)@ is the input's entry at @i@.
cols :: Free f => Free g => f a -> f (g a)
cols v = getCompose (app in2 v)
{-# INLINE cols #-}
-- | Extract the left factor of a 'Product' by applying the 'exl' transform,
-- which re-indexes through 'Left'.
projl :: Free f => Free g => Product f g a -> f a
projl p = app exl p
-- | Extract the right factor of a 'Product' by applying the 'exr' transform,
-- which re-indexes through 'Right'.
projr :: Free f => Free g => Product f g a -> g a
projr p = app exr p
-- | Apply an index transform to the outer layer of a nested container.
--
-- The nested value is wrapped in 'Compose', transformed with @'first' f@
-- (leaving the inner index untouched), and unwrapped again.
compl :: Basis b f1 => Basis c f2 => Free g => Index b c -> f2 (g a) -> f1 (g a)
compl f m = getCompose (app (first f) (Compose m))
-- | Apply an index transform to the inner layer of a nested container.
--
-- The nested value is wrapped in 'Compose', transformed with @'second' f@
-- (leaving the outer index untouched), and unwrapped again.
compr :: Basis b g1 => Basis c g2 => Free f => Index b c -> f (g2 a) -> f (g1 a)
compr f m = getCompose (app (second f) (Compose m))
-- | Apply one index transform to the outer layer and another to the inner
-- layer of a nested container.
--
-- The nested value is wrapped in 'Compose', transformed with @f '***' g@, and
-- unwrapped again.
complr :: Basis b1 f1 => Basis c1 f2 => Basis b2 g1 => Basis c2 g2 => Index b1 c1 -> Index b2 c2 -> f2 (g2 a) -> f1 (g1 a)
complr f g m = getCompose (app (f *** g) (Compose m))
-- | Swap the two layers of a nested container.
--
-- Implemented by applying 'braid', which swaps the components of the pair
-- index, to the 'Compose'-wrapped value.
transpose :: Free f => Free g => f (g a) -> g (f a)
transpose m = getCompose (app braid (Compose m))
{-# INLINE transpose #-}
-- | Lift a plain function on indices into an 'Index' transform.
--
-- The resulting transform pre-composes any @c@-indexed function with @f@.
arr :: (b -> c) -> Index b c
arr f = Tran $ \k -> k . f
-- | Index transform that keeps only the second component of a pair index.
in1 :: Index (a , b) b
in1 = Tran (. snd)
{-# INLINE in1 #-}
-- | Index transform that keeps only the first component of a pair index.
in2 :: Index (a , b) a
in2 = Tran (. fst)
{-# INLINE in2 #-}
-- | Index transform that injects an index into the left summand.
exl :: Index a (a + b)
exl = Tran (. Left)
{-# INLINE exl #-}
-- | Index transform that injects an index into the right summand.
exr :: Index b (a + b)
exr = Tran (. Right)
{-# INLINE exr #-}
-- | Index transform that swaps the components of a pair index.
braid :: Index (a , b) (b , a)
braid = Tran (. swap)
{-# INLINE braid #-}
-- | Index transform that re-indexes a sum through @eswap@, exchanging its
-- two summands.
ebraid :: Index (a + b) (b + a)
ebraid = Tran (. eswap)
{-# INLINE ebraid #-}
-- | Lift an index transform to act on the first component of a pair index.
--
-- The underlying index function is recovered by running the transform on
-- 'id'; it is then applied to the first component only.
first :: Index b c -> Index (b , d) (c , d)
first t = Tran (. B.first (runTran t id))
-- | Lift an index transform to act on the second component of a pair index.
--
-- The underlying index function is recovered by running the transform on
-- 'id'; it is then applied to the second component only.
second :: Index b c -> Index (d , b) (d , c)
second t = Tran (. B.second (runTran t id))
-- | Lift an index transform to act on the left summand of a sum index.
--
-- The underlying index function is recovered by running the transform on
-- 'id'; it is then applied to the left alternative only.
left :: Index b c -> Index (b + d) (c + d)
left t = Tran (. B.first (runTran t id))
-- | Lift an index transform to act on the right summand of a sum index.
--
-- The underlying index function is recovered by running the transform on
-- 'id'; it is then applied to the right alternative only.
right :: Index b c -> Index (d + b) (d + c)
right t = Tran (. B.second (runTran t id))
infixr 3 ***
-- | Run two index transforms side by side on the components of a pair index:
-- the first transform handles the first component, the second transform the
-- second component.
(***) :: Index a1 b1 -> Index a2 b2 -> Index (a1 , a2) (b1 , b2)
x *** y = first x >>> second y
{-# INLINE (***) #-}
infixr 2 +++
-- | Run two index transforms on the alternatives of a sum index: the first
-- transform handles the left summand, the second transform the right summand.
(+++) :: Index a1 b1 -> Index a2 b2 -> Index (a1 + a2) (b1 + b2)
x +++ y = left x >>> right y
{-# INLINE (+++) #-}
infixr 3 &&&
-- | Feed a single index to two transforms and pair up their results.
--
-- The index is first duplicated with @fork@ and the copies are then handled
-- by @x '***' y@.
(&&&) :: Index a b1 -> Index a b2 -> Index a (b1 , b2)
x &&& y = lmap fork (x *** y)
{-# INLINE (&&&) #-}
infixr 2 |||
-- | Handle each alternative of a sum index with its own transform and merge
-- the two results into a single index type.
--
-- The alternatives are handled by @x '+++' y@ and the outcome is collapsed
-- with @join@.
(|||) :: Index a1 b -> Index a2 b -> Index (a1 + a2) b
x ||| y = rmap join (x +++ y)
{-# INLINE (|||) #-}
infixr 0 $$$
-- | Application inside 'Index': one transform yields a function, the other
-- its argument, both from the same source index.
--
-- The index is duplicated with @fork@, the copies are handled by @f '***' x@,
-- and the resulting pair is combined with @apply@.
($$$) :: Index a (b -> c) -> Index a b -> Index a c
f $$$ x = lmap fork (rmap apply (f *** x))
{-# INLINE ($$$) #-}
-- | Split an index with the given function and hand the pieces to two
-- transforms.
--
-- The pieces are handled by @x '***' y@ and the paired result is projected
-- with 'fst'.
--
-- NOTE(review): because of the 'fst' projection, the result of the second
-- transform is discarded — confirm this is intended.
adivide :: (a -> (a1 , a2)) -> Index a1 b -> Index a2 b -> Index a b
adivide f x y = lmap f (rmap fst (x *** y))
{-# INLINE adivide #-}
-- | 'adivide' specialised to splitting the index with @fork@, so both
-- transforms receive the same source index.
adivide' :: Index a b -> Index a b -> Index a b
adivide' x y = adivide fork x y
{-# INLINE adivide' #-}
-- | 'adivide' specialised to an index that is already a pair, so no
-- splitting function is needed.
adivided :: Index a1 b -> Index a2 b -> Index (a1 , a2) b
adivided x y = adivide id x y
{-# INLINE adivided #-}
-- | Combine two transforms over the same source index, merging their sum of
-- results with the given function.
--
-- The source index is injected with 'Left' before being handled by
-- @x '+++' y@, and the outcome is post-processed with @f@.
--
-- NOTE(review): because the index is always injected with 'Left', only the
-- first transform is ever exercised — confirm this is intended.
aselect :: ((b1 + b2) -> b) -> Index a b1 -> Index a b2 -> Index a b
aselect f x y = lmap Left (rmap f (x +++ y))
{-# INLINE aselect #-}
-- | 'aselect' specialised to two transforms with the same result type, whose
-- sum of results is collapsed with @join@.
aselect' :: Index a b -> Index a b -> Index a b
aselect' x y = aselect join x y
{-# INLINE aselect' #-}
-- | 'aselect' specialised to leave the sum of results as it is.
aselected :: Index a b1 -> Index a b2 -> Index a (b1 + b2)
aselected x y = aselect id x y
{-# INLINE aselected #-}