{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Downhill.Internal.Graph.OpenGraph
  ( OpenEdge (..),
    OpenEndpoint (..),
    OpenNode (..),
    OpenGraph (..),
    recoverSharing,
  )
where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT (..), get, modify)
import Downhill.Internal.Graph.OpenMap (OpenKey, OpenMap)
import qualified Downhill.Internal.Graph.OpenMap as OpenMap
import Downhill.Internal.Graph.Types (BackFun (BackFun))
import Downhill.Linear.Expr (BasicVector, Expr (ExprSum, ExprVar), Term (..))
import Prelude hiding (lookup)

-- | The place an 'OpenEdge' points at: either the source of the whole graph
-- or an inner node registered in the cache.
data OpenEndpoint a v where
  -- | The graph's source node. Its vector type is necessarily @a@.
  OpenSourceNode :: OpenEndpoint a a
  -- | An inner node, identified by the 'OpenKey' it is stored under.
  OpenInnerNode :: OpenKey v -> OpenEndpoint a v

-- | An edge of a node of type @v@. It pairs a backward function with the
-- endpoint of type @u@ that the function targets. The wrapped 'BackFun' maps
-- a @v@ to a @VecBuilder u@.
data OpenEdge a v where
  OpenEdge ::
    BackFun u v ->
    -- ^ backward function from the node's type @v@ to the endpoint's type @u@
    OpenEndpoint a u ->
    -- ^ endpoint the edge refers to
    OpenEdge a v

-- | A node of the graph under construction. It holds one 'OpenEdge' per term
-- it was built from. The 'BasicVector' instance for @v@ is captured by the
-- constructor, so it becomes available again on pattern match.
data OpenNode a v where
  OpenNode :: forall a v. BasicVector v => [OpenEdge a v] -> OpenNode a v

-- | Maintains a cache of visited 'Expr's while the graph is being built.
--
-- This is a state transformer over 'IO'. The state maps each 'OpenKey' to the
-- 'OpenNode' already built for it. 'IO' is the base monad because keys are
-- created with 'OpenMap.makeOpenKey', which runs in 'IO'.
newtype TreeBuilder a r = TreeCache
  { unTreeCache :: StateT (OpenMap (OpenNode a)) IO r
  }
  deriving (Functor, Applicative, Monad)

-- | Stores @value@ in the cache under the key @name@.
insertIntoCache :: OpenKey dv -> OpenNode a dv -> TreeBuilder a ()
insertIntoCache name value = TreeCache (modify (OpenMap.insert name value))

-- | @buildExpr action key@ will run @action@, associate the result with @key@
-- and store it in the cache. If @key@ is already in the cache, @action@ is not
-- run and the cached node is returned instead. In both cases the 'OpenKey'
-- derived for @key@ is returned alongside the node.
--
-- NOTE(review): the key is recomputed on every call with
-- 'OpenMap.makeOpenKey'. Sharing detection therefore relies on that function
-- returning equal keys for the same expression. This is presumably
-- pointer-identity based; confirm in "Downhill.Internal.Graph.OpenMap".
buildExpr ::
  TreeBuilder a (OpenNode a v) ->
  Expr a v ->
  TreeBuilder a (OpenKey v, OpenNode a v)
buildExpr action key = do
  name <- TreeCache (lift (OpenMap.makeOpenKey key))
  cache <- TreeCache get
  case OpenMap.lookup cache name of
    Just cached -> pure (name, cached)
    Nothing -> do
      value <- action
      insertIntoCache name value
      pure (name, value)

-- | Runs a 'TreeBuilder' starting from an empty cache. Returns the builder's
-- result together with the final cache of every node built along the way.
--
-- The result type is left fully polymorphic. The builder's result need not
-- have the shape @g dv@.
runTreeBuilder :: forall a r. TreeBuilder a r -> IO (r, OpenMap (OpenNode a))
runTreeBuilder builder = runStateT (unTreeCache builder) OpenMap.empty

-- | Computational graph under construction. The word /open/ refers to the set
-- of nodes: new nodes can still be added to this graph. Once the graph is
-- complete, the set of nodes will be frozen and the graph will become a
-- 'Graph' (see "Downhill.Internal.Graph").
data OpenGraph a z
  = OpenGraph
      (OpenNode a z)
      -- ^ Root node, built from the top-level list of terms.
      (OpenMap (OpenNode a))
      -- ^ Inner nodes collected while traversing from the root, keyed by
      -- their 'OpenKey'.

-- | Builds a node from a list of terms, one 'OpenEdge' per term. Every
-- argument expression the terms refer to is visited, and cached via
-- 'goSharing4term'.
goEdges :: BasicVector v => [Term a v] -> TreeBuilder a (OpenNode a v)
goEdges xs = OpenNode <$> traverse goSharing4term xs

-- | Converts an argument expression into the endpoint an edge should point at:
--
-- * 'ExprVar' becomes the graph's source node.
-- * 'ExprSum' becomes an inner node, built with 'goEdges' or taken from the
--   cache by 'buildExpr'.
--
-- The @case@ evaluates @key@ to weak head normal form before it is passed to
-- 'buildExpr'.
goSharing4arg :: forall a v. Expr a v -> TreeBuilder a (OpenEndpoint a v)
goSharing4arg key = case key of
  ExprVar -> pure OpenSourceNode
  ExprSum xs -> OpenInnerNode . fst <$> buildExpr (goEdges xs) key

-- | Converts a single term into an edge. The term's function is wrapped in
-- 'BackFun', and the term's argument is resolved to an endpoint with
-- 'goSharing4arg'.
goSharing4term :: forall a v. Term a v -> TreeBuilder a (OpenEdge a v)
goSharing4term = \case
  Term f arg -> OpenEdge (BackFun f) <$> goSharing4arg arg

-- | Collects duplicate nodes in an 'Expr' tree and converts it to a graph.
--
-- The given terms form the root node. Every 'ExprSum' reachable from them is
-- built once per 'OpenKey' and stored in the returned node map.
recoverSharing :: forall a z. BasicVector z => [Term a z] -> IO (OpenGraph a z)
recoverSharing xs = uncurry OpenGraph <$> runTreeBuilder (goEdges xs)