module Language.Syntactic.Sharing.Graph where
import Control.Arrow ((***))
import Control.Monad.Reader
import Data.Array
import Data.Function
import Data.List
import Data.Typeable
import Data.Hash
import Language.Syntactic
import Language.Syntactic.Constructs.Binding
import Language.Syntactic.Sharing.Utils
-- | Identifier of a node in an 'ASG'. It wraps an 'Integer'; the numeric and
-- 'Ix' instances let node ids be used directly as array indices.
newtype NodeId = NodeId { nodeInteger :: Integer }
  deriving (Eq, Ord, Num, Real, Integral, Enum, Ix)

-- | Shows only the underlying number, without the constructor name.
instance Show NodeId
  where
    show = show . nodeInteger

-- | Render a node id with a @node:@ prefix, e.g. @node:3@.
showNode :: NodeId -> String
showNode = ("node:" ++) . show
-- | Placeholder for a shared sub-expression: a reference to the graph node
-- with the given id. A node always stands for a complete expression (its
-- type index is 'Full').
data Node a
  where
    Node :: NodeId -> Node (Full b)

-- | Node references place no constraint on their result type.
instance Constrained Node
  where
    type Sat Node = Top
    exprDict = const Dict

-- | A reference is rendered as @node:<id>@.
instance Render Node
  where
    renderSym (Node nid) = showNode nid

instance StringTree Node
-- | Environments that carry a 'NodeEnv', so that alpha-equivalence of node
-- references can look up the referenced nodes.
class NodeEqEnv dom a
  where
    -- Extract the node environment
    prjNodeEqEnv :: a -> NodeEnv dom (Sat dom)
    -- Apply a function to the node environment, leaving the rest untouched
    modNodeEqEnv :: (NodeEnv dom (Sat dom) -> NodeEnv dom (Sat dom)) -> (a -> a)

-- | Environment for alpha-equivalence on graphs: a list of variable pairs
-- (accessed through 'VarEqEnv') together with a 'NodeEnv'.
type EqEnv dom p = ([(VarId,VarId)], NodeEnv dom p)

-- | Lookup tables indexed by 'NodeId': the hash of each node and the
-- expression of each node.
type NodeEnv dom p =
    ( Array NodeId Hash
    , Array NodeId (ASTB dom p)
    )

instance (p ~ Sat dom) => NodeEqEnv dom (EqEnv dom p)
  where
    prjNodeEqEnv (_, nodeEnv) = nodeEnv
    -- Irrefutable pattern: same laziness as @id *** f@ on functions
    modNodeEqEnv f ~(vars, nodeEnv) = (vars, f nodeEnv)

instance VarEqEnv (EqEnv dom p)
  where
    prjVarEqEnv (vars, _) = vars
    -- Irrefutable pattern: same laziness as @f *** id@ on functions
    modVarEqEnv f ~(vars, nodeEnv) = (f vars, nodeEnv)
-- | Alpha-equivalence of node references.
--
-- Two references to the same node are equal. Otherwise the hashes of the two
-- nodes (taken from the 'NodeEnv' in the environment) are compared first as
-- a cheap filter. Only when the hashes agree are the referenced expressions
-- looked up and compared with 'alphaEqM'. This presumes that alpha-equivalent
-- nodes always have equal hashes.
instance (AlphaEq dom dom dom env, NodeEqEnv dom env) =>
    AlphaEq Node Node dom env
  where
    alphaEqSym (Node n1) Nil (Node n2) Nil
        | n1 == n2  = return True
        | otherwise = do
            (hashTab, exprTab) :: NodeEnv dom (Sat dom) <- asks prjNodeEqEnv
            let sameHash = hashTab ! n1 == hashTab ! n2
            if sameHash
              then case (exprTab ! n1, exprTab ! n2) of
                  (ASTB a, ASTB b) -> alphaEqM a b
              else return False
-- | \"Abstract Semantic Graph\": a syntax tree with explicit sharing. Shared
-- sub-expressions live in 'graphNodes' and are referenced from expressions
-- through 'Node' symbols. Node ids are expected to be unique within a graph.
data ASG dom a = ASG
    { topExpression :: ASTF (NodeDomain dom) a
      -- ^ Top-level expression
    , graphNodes    :: [(NodeId, ASTSAT (NodeDomain dom))]
      -- ^ Mapping from node id to the node's expression
    , numNodes      :: NodeId
      -- ^ Number of nodes; the table-building functions in this module index
      -- nodes over @(0, numNodes - 1)@
    }

-- | Domain of graph expressions: the user domain extended with 'Node'
-- references, constrained by the user domain's 'Sat' predicate.
type NodeDomain dom = (Node :+: dom) :|| Sat dom
-- | Render a syntax graph as ASCII art: first the top-level expression, then
-- every node in 'graphNodes' order.
--
-- Each section starts with a header line @---- <label> -----...@. The label
-- is padded with dashes to 40 characters; no dashes are added once the label
-- is 40 characters or longer.
showASG :: forall dom a. StringTree dom => ASG dom a -> String
showASG (ASG top nodes _) =
    unlines ((line "top" ++ showAST top) : map showNodeEntry nodes)
  where
    -- Section header. The padding width must be 40 minus the label length.
    -- (The extracted source had lost the minus, which is ill-typed.)
    line :: String -> String
    line str = "---- " ++ str ++ " " ++ rest ++ "\n"
      where
        rest = replicate (40 - length str) '-'

    -- Header labelled with the top-level 'showNode', followed by the node's
    -- expression
    showNodeEntry :: (NodeId, ASTSAT (NodeDomain dom)) -> String
    showNodeEntry (n, ASTB expr) = line (showNode n) ++ showAST expr
-- | Print a syntax graph to standard output using 'showASG'.
drawASG :: StringTree dom => ASG dom a -> IO ()
drawASG graph = putStrLn (showASG graph)
-- | Rename every 'Node' reference in an expression using the given function.
-- All other symbols are returned unchanged.
reindexNodesAST ::
    (NodeId -> NodeId) -> AST (NodeDomain dom) a -> AST (NodeDomain dom) a
reindexNodesAST reix expr = case expr of
    Sym (C' (InjL (Node n))) -> injC $ Node $ reix n
    f :$ x                   -> reindexNodesAST reix f :$ reindexNodesAST reix x
    _                        -> expr
-- | Rename every node id in a graph with the given function. This covers both
-- the ids in the node table and the 'Node' references inside all expressions
-- (including the top-level one). 'numNodes' is kept as is, so if the function
-- is not injective the result holds several entries with the same id (see
-- 'nubNodes').
reindexNodes :: (NodeId -> NodeId) -> ASG dom a -> ASG dom a
reindexNodes reix (ASG top nodes total) =
    ASG (reindexNodesAST reix top) renamed total
  where
    renamed =
      [ (reix nid, ASTB (reindexNodesAST reix expr))
      | (nid, ASTB expr) <- nodes
      ]
-- | Renumber the nodes of a graph from the list of their current ids, using
-- 'reindex' (from "Language.Syntactic.Sharing.Utils"). The intent is to map
-- the ids onto @[0 .. l-1]@, where @l@ is the number of nodes. TODO confirm
-- against 'reindex'.
reindexNodesFrom0 :: ASG dom a -> ASG dom a
reindexNodesFrom0 graph = reindexNodes (reindex oldIds) graph
  where
    oldIds = map fst (graphNodes graph)
-- | Remove node-table entries whose id has already appeared, keeping the
-- first entry for each id. Only the 'NodeId' is compared, not the
-- expression. 'numNodes' is set to the number of remaining entries.
nubNodes :: ASG dom a -> ASG dom a
nubNodes (ASG top nodes _) =
    let unique = nubBy (\x y -> fst x == fst y) nodes
    in  ASG top unique (genericLength unique)
-- | Pattern functor for expressions over a 'NodeDomain': one layer of syntax
-- in which the recursive positions are replaced by the parameter @a@.
data SyntaxPF dom a
  where
    -- Application of a function to an argument
    AppPF  :: a -> a -> SyntaxPF dom a
    -- Reference to the node with the given id, together with a value for
    -- that node (in 'foldGraph', the node's folded result)
    NodePF :: NodeId -> a -> SyntaxPF dom a
    -- A symbol of the underlying domain; its type index is hidden
    DomPF  :: dom b -> SyntaxPF dom a

instance Functor (SyntaxPF dom)
  where
    fmap f layer = case layer of
        AppPF fun arg -> AppPF (f fun) (f arg)
        NodePF nid r  -> NodePF nid (f r)
        DomPF sym     -> DomPF sym
-- | Fold over a graph.
--
-- The caller supplies an algebra that folds a single 'SyntaxPF' layer. The
-- result is the fold of the top-level expression, paired with the fold of
-- every node. The node results are given both as an array indexed by
-- 'NodeId' and as an association list in 'graphNodes' order.
--
-- Each node is folded once. A 'Node' reference does not re-fold the
-- referenced expression; it looks up the node's lazily computed result in
-- the shared array.
--
-- The array has bounds @(0, numNodes - 1)@, so every node id must lie in
-- that range.
foldGraph :: forall dom a b .
    (SyntaxPF dom b -> b) -> ASG dom a -> (b, (Array NodeId b, [(NodeId,b)]))
foldGraph alg (ASG top ns nn) = (g top, (arr, nodes))
  where
    nodes = [(n, g expr) | (n, ASTB expr) <- ns]
    -- Upper bound is the last valid id, nn - 1
    arr   = array (0, nn - 1) nodes

    g :: AST (NodeDomain dom) c -> b
    g (h :$ a)                   = alg $ AppPF (g h) (g a)
    g (Sym (C' (InjL (Node n)))) = alg $ NodePF n (arr ! n)
    g (Sym (C' (InjR a)))        = alg $ DomPF a
-- | Convert an 'ASG' to a plain 'AST' by inlining all nodes.
--
-- Every 'Node' reference is replaced by the (recursively inlined) expression
-- of the referenced node, so the result carries no sharing information.
--
-- 'Typeable' is used to cast each node's expression to the type expected at
-- the reference site. A failed cast calls 'error' with
-- \"inlineAll: type mismatch\".
--
-- Node ids are assumed to lie in @[0 .. numNodes - 1]@.
inlineAll :: forall dom a . ConstrainedBy dom Typeable =>
    ASG dom a -> ASTF dom a
inlineAll (ASG top nodes numN) = inline top
  where
    -- Upper bound is the last valid id, numN - 1
    nodeMap = array (0, numN - 1) nodes

    inline :: AST (NodeDomain dom) b -> AST dom b
    inline (s :$ a) = inline s :$ inline a
    inline s@(Sym (C' (InjL (Node n)))) = case nodeMap ! n of
        ASTB a
          | Dict <- exprDictSub pTypeable s
          , Dict <- exprDictSub pTypeable a
          -> case gcast a of
               Nothing -> error "inlineAll: type mismatch"
               Just a' -> inline a'
    inline (Sym (C' (InjR a))) = Sym a
-- | For each node, list its child nodes: the 'Node' references that occur
-- in the node's expression. The search stops at the first 'Node' on each
-- path, so nodes reachable only through another node are not included.
-- Children of 'topExpression' are not reported.
nodeChildren :: ASG dom a -> [(NodeId, [NodeId])]
nodeChildren graph = map (id *** fromDList) perNode
  where
    (_, (_, perNode)) = foldGraph children graph

    children :: SyntaxPF dom (DList NodeId) -> DList NodeId
    children layer = case layer of
        AppPF ls rs -> ls . rs
        NodePF n _  -> single n
        _           -> empty
-- | Count how often each node is referenced as a child of some node (as
-- reported by 'nodeChildren').
--
-- 'count' is given the bounds @(0, numNodes graph - 1)@. Presumably it
-- returns a histogram over that range, so every referenced id must lie in
-- it.
--
-- NOTE(review): 'nodeChildren' only reports children of 'graphNodes', so
-- references that occur in 'topExpression' are not counted. A node referenced
-- only from the top-level expression gets count 0, even if it is referenced
-- there several times. 'inlineSingle' relies on these counts, so confirm
-- this is intended.
occurrences :: ASG dom a -> Array NodeId Int
occurrences graph
    = count (0, numNodes graph - 1)
    $ concatMap snd
    $ nodeChildren graph
-- | Inline every node that is referenced at most once, according to
-- 'occurrences'. Nodes with more than one occurrence stay in the graph, and
-- references to them are kept.
--
-- 'Typeable' is used to cast an inlined node's expression to the type
-- expected at the reference site. A failed cast calls 'error' with
-- \"inlineSingle: type mismatch\".
--
-- NOTE(review): surviving nodes keep their original ids, while 'numNodes' is
-- set to how many survive. The ids may therefore fall outside
-- @[0 .. numNodes - 1]@, which 'foldGraph'-based functions require.
-- Presumably callers renumber with 'reindexNodesFrom0' first. TODO confirm.
inlineSingle :: forall dom a . ConstrainedBy dom Typeable =>
    ASG dom a -> ASG dom a
inlineSingle graph@(ASG top nodes numN) = ASG top' nodes' numN'
  where
    -- Upper bound is the last valid id, numN - 1
    nodeTab = array (0, numN - 1) nodes
    occs    = occurrences graph
    top'    = inline top
    -- Only nodes still referenced more than once are kept
    nodes'  = [(nid, ASTB (inline e)) | (nid, ASTB e) <- nodes, occs ! nid > 1]
    numN'   = genericLength nodes'

    inline :: AST (NodeDomain dom) b -> AST (NodeDomain dom) b
    inline (s :$ a) = inline s :$ inline a
    inline s@(Sym (C' (InjL (Node n))))
        | occs ! n > 1 = injC $ Node n
        | otherwise    = case nodeTab ! n of
            ASTB a
              | Dict <- exprDictSub pTypeable s
              , Dict <- exprDictSub pTypeable a
              -> case gcast a of
                   Nothing -> error "inlineSingle: type mismatch"
                   Just a' -> inline a'
    inline (Sym (C' (InjR a))) = Sym $ C' $ InjR a
-- | Compute a hash for every node, as both an array and an association list
-- (see 'foldGraph'). Applications and domain symbols are tagged with distinct
-- seeds (0 and 1). A node reference hashes to the hash of the node it
-- references.
hashNodes :: Equality dom => ASG dom a -> (Array NodeId Hash, [(NodeId, Hash)])
hashNodes graph = snd (foldGraph hashNode graph)
  where
    hashNode :: Equality d => SyntaxPF d Hash -> Hash
    hashNode layer = case layer of
        AppPF hf hx -> hashInt 0 `combine` hf `combine` hx
        NodePF _ hn -> hn
        DomPF sym   -> hashInt 1 `combine` exprHash sym
-- | Partition the nodes so that two nodes share a sub-list when they are
-- alpha-equivalent.
--
-- The nodes are first grouped by hash ('hashNodes'). Each group is then
-- refined with 'fullPartition', using alpha-equivalence ('alphaEqM') as the
-- equality. The alpha-equivalence check runs in an environment holding the
-- hash and expression tables, indexed over @(0, numNodes graph - 1)@.
partitionNodes :: forall dom a
    .  ( Equality dom
       , AlphaEq dom dom (NodeDomain dom) (EqEnv (NodeDomain dom) (Sat dom))
       )
    => ASG dom a -> [[NodeId]]
partitionNodes graph = concatMap (fullPartition nodeEq) approxPartitioning
  where
    -- Upper bound is the last valid id, numNodes - 1
    nTab = array (0, numNodes graph - 1) (graphNodes graph)
    (hTab, hashes) = hashNodes graph

    -- Group node ids with equal hashes
    approxPartitioning
        = map (map fst)
        $ groupBy ((==) `on` snd)
        $ sortBy (compare `on` snd)
        $ hashes

    nodeEq :: NodeId -> NodeId -> Bool
    nodeEq n1 n2 = runReader
        (liftASTB2 alphaEqM (nTab ! n1) (nTab ! n2))
        (([], (hTab, nTab)) :: EqEnv (NodeDomain dom) (Sat dom))
-- | Common sub-expression elimination based on alpha-equivalence.
--
-- The steps are:
--
-- 1. Partition the nodes into alpha-equivalence classes ('partitionNodes').
-- 2. Give every node in the @p@-th class the new id @p@, both in the node
--    table and in all references ('reindexNodes').
-- 3. Remove duplicate ids with 'nubNodes', keeping one representative per
--    class.
--
-- Presumably the result has ids @0 .. k-1@ for @k@ classes.
-- TODO confirm that 'fullPartition' never yields empty classes.
cse
    :: ( Equality dom
       , AlphaEq dom dom (NodeDomain dom) (EqEnv (NodeDomain dom) (Sat dom))
       )
    => ASG dom a -> ASG dom a
cse graph@(ASG _ _ numN) = nubNodes $ reindexNodes (reixTab !) graph
  where
    parts   = partitionNodes graph
    -- Old node id -> index of its equivalence class. The upper bound is the
    -- last valid id, numN - 1.
    reixTab = array (0, numN - 1)
        [ (nid, p) | (part, p) <- parts `zip` [0..], nid <- part ]