{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Pretty.Graphviz (
Graph,
PrettyGraph(..), Detail(..),
graphDelayedAcc, graphDelayedAfun,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Pretty.Graphviz.Monad
import Data.Array.Accelerate.Pretty.Graphviz.Type
import Data.Array.Accelerate.Pretty.Print hiding ( Keyword(..) )
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Trafo.Delayed
import Data.Array.Accelerate.Trafo.Substitution
import Control.Applicative hiding ( Const, empty )
import Control.Arrow ( (&&&) )
import Control.Monad.State ( modify, gets, state )
import Data.HashSet ( HashSet )
import Data.List ( nub, partition )
import Data.Maybe
import Data.String
import Data.Text.Prettyprint.Doc
import System.IO.Unsafe ( unsafePerformIO )
import Prelude hiding ( exp )
import qualified Data.HashSet as Set
import qualified Data.Sequence as Seq
-- | Compile-time switch: when 'True', taking the shape of an array variable
-- inside a scalar expression is recorded as a dependency on that array.
--
cfgIncludeShape :: Bool
cfgIncludeShape = False

-- | Compile-time switch: when 'True', duplicate dependencies of a node are
-- removed before its incoming edges are created.
--
cfgUnique :: Bool
cfgUnique = False
-- | Array environment used while drawing the graph. It is indexed by the
-- same environment type as the terms being printed, and records for every
-- array variable in scope the identifier of the graph node that produces it
-- together with the label it is displayed under.
--
data Aval env where
  -- | The empty environment.
  Aempty :: Aval ()
  -- | Extend an environment with one more array variable, given the node
  -- that defines it and its label.
  Apush  :: Aval env -> NodeId -> Label -> Aval (env, t)
-- | Forget the node identifiers of an array environment, keeping only the
-- (pretty-printed) labels, in the form expected by the expression printer.
--
avalToVal :: Aval aenv -> Val aenv
avalToVal env =
  case env of
    Aempty            -> Empty
    Apush rest _ name -> Push (avalToVal rest) (pretty name)
-- | Look up an array variable in the environment, returning the identifier
-- of the node which defines it and the label it is displayed under.
--
aprj :: Idx aenv t -> Aval aenv -> (NodeId, Label)
aprj idx env =
  case (idx, env) of
    (ZeroIdx,      Apush _ node label) -> (node, label)
    (SuccIdx idx', Apush rest _ _)     -> aprj idx' rest
-- | Add a node to the graph under construction, together with one edge for
-- every dependency recorded in the 'PNode'. Each dependency pairs the source
-- vertex with the port of this node that the edge should arrive at.
--
-- The new node is put at the front of the node sequence and the new edges at
-- the front of the edge sequence. Returns the identifier of the added node.
--
mkNode :: PNode -> Maybe Label -> Dot NodeId
mkNode (PNode ident tree deps) label = do
  modify $ \s -> s { dotNodes = node     Seq.<| dotNodes s
                   , dotEdges = incoming Seq.>< dotEdges s
                   }
  return ident
  where
    node = Node label ident tree

    -- Optionally collapse repeated dependencies into a single edge
    sources
      | cfgUnique = nub deps
      | otherwise = deps

    incoming = Seq.fromList [ Edge from (Vertex ident port) | (from, port) <- sources ]
-- | Extend a cell with a pair of cells labelled @T@ and @F@, each carrying a
-- port of the same name, so that edges can be attached to the true and false
-- outcomes individually.
--
mkTF :: Tree (Maybe Port, Adoc) -> Tree (Maybe Port, Adoc)
mkTF this = Forest [ this, outcomes ]
  where
    outcomes = Forest [ outcome "T", outcome "F" ]

    -- A cell whose port name and displayed text are the same string
    outcome :: String -> Tree (Maybe Port, Adoc)
    outcome tag = Leaf (Just (fromString tag), fromString tag)
-- | Things which can be rendered as a graph.
--
class PrettyGraph g where
  -- | Produce the graph of the given term at the requested level of detail.
  ppGraph :: Detail -> g -> Graph

-- | Closed delayed array computations are drawn with 'graphDelayedAcc'.
instance PrettyGraph (DelayedAcc a) where
  ppGraph = graphDelayedAcc

-- | Closed delayed array functions are drawn with 'graphDelayedAfun'.
instance PrettyGraph (DelayedAfun a) where
  ppGraph = graphDelayedAfun
-- | How much information to include in each node of the graph.
--
data Detail = Simple | Full

-- | Is the reduced ('Simple') level of detail requested?
--
simple :: Detail -> Bool
simple detail =
  case detail of
    Simple -> True
    Full   -> False
-- | Generate the graph of a closed delayed array computation.
--
-- The graph is built in the 'Dot' monad, which 'evalDot' runs in 'IO'; the
-- result is exposed as a pure value with 'unsafePerformIO'.
-- NOTE(review): presumably safe because the only effect is the generation of
-- fresh identifiers, and NOINLINE keeps the action from being duplicated --
-- confirm against the 'Dot' monad implementation.
--
{-# NOINLINE graphDelayedAcc #-}
graphDelayedAcc :: HasCallStack => Detail -> DelayedAcc a -> Graph
graphDelayedAcc detail acc = unsafePerformIO $! evalDot build
  where
    build = graphDelayedOpenAcc detail Aempty acc
-- | Generate the graph of a closed delayed array function.
--
-- Drawing the function registers its graph in the state under a fresh label.
-- That graph is expected to be the first one in the list of graphs; it is
-- removed from the state and returned. Anything else is an internal error.
--
-- As with 'graphDelayedAcc', the 'Dot' computation is run via 'evalDot' and
-- 'unsafePerformIO'.
--
{-# NOINLINE graphDelayedAfun #-}
graphDelayedAfun :: HasCallStack => Detail -> DelayedAfun f -> Graph
graphDelayedAfun detail afun = unsafePerformIO . evalDot $! do
  label <- prettyDelayedAfun detail Aempty afun
  state (takeFirst label)
  where
    -- Pop the leading graph, provided it carries the expected label
    takeFirst label s
      | g@(Graph label' _) Seq.:< rest <- Seq.viewl (dotGraph s)
      , label == label'
      = (g, s { dotGraph = rest })
      | otherwise
      = internalError "unexpected error"
-- | A pretty-printed fragment together with the vertices (graph nodes) that
-- the fragment refers to.
--
data PDoc  = PDoc Adoc [Vertex]

-- | The description of a graph node which has not yet been added to the
-- graph: its identifier, the tree of cells (each an optional port and its
-- contents) making up its body, and its dependencies. Each dependency pairs
-- the vertex an edge comes from with the port of this node it arrives at.
--
data PNode = PNode NodeId (Tree (Maybe Port, Adoc)) [(Vertex, Maybe Port)]
-- | Draw an array computation as a complete graph: the computation itself is
-- added as a node, followed by an extra node reading @result@ with an edge
-- from the computation to it. The accumulated state is then turned into a
-- graph by 'mkGraph'.
--
graphDelayedOpenAcc
    :: HasCallStack
    => Detail
    -> Aval aenv
    -> DelayedOpenAcc aenv a
    -> Dot Graph
graphDelayedOpenAcc detail aenv acc = do
  top      <- prettyDelayedOpenAcc detail context0 aenv acc
  resultId <- mkNodeId top
  topId    <- mkNode top Nothing
  let sink = PNode resultId
                   (Leaf (Nothing, "result"))
                   [ (Vertex topId Nothing, Nothing) ]
  _        <- mkNode sink Nothing
  mkGraph
-- | Translate a manifest array computation into the description of a graph
-- node.
--
-- The returned 'PNode' has /not/ been added to the graph; the caller decides
-- how, and under which label, it is emitted with 'mkNode'. Sub-computations
-- that need to be drawn as nodes of their own (the bound term of a let, the
-- branches of a conditional, manifest arguments which are not variables) are
-- added to the graph as a side effect.
--
-- Arguments:
--
--  * the level of detail: with 'Simple' only the operator name is shown, and
--    the predicate of a conditional is replaced by a placeholder;
--
--  * the printing context, used to decide whether the operator application
--    must be parenthesised;
--
--  * the array environment, giving the node and label of each array variable;
--
--  * the term to draw. A 'Delayed' term at this position is an internal error.
--
prettyDelayedOpenAcc
    :: forall aenv arrs. HasCallStack
    => Detail
    -> Context
    -> Aval aenv
    -> DelayedOpenAcc aenv arrs
    -> Dot PNode
prettyDelayedOpenAcc _      _   _    Delayed{}            = internalError "expected manifest array"
prettyDelayedOpenAcc detail ctx aenv atop@(Manifest pacc) =
  case pacc of
    Avar ix                 -> pnode (avar ix)

    -- The bound computation becomes a node of its own, labelled with the
    -- variable(s) it binds. The node for the let expression is that of its
    -- body, drawn in the extended environment.
    Alet lhs bnd body       -> do
      bnd'@(PNode ident _ _) <- prettyDelayedOpenAcc detail context0 aenv bnd
      (aenv1, a)             <- prettyLetALeftHandSide ident aenv lhs
      _                      <- mkNode bnd' (Just a)
      prettyDelayedOpenAcc detail context0 aenv1 body

    -- A conditional is a node with three ports: "P" receives the free array
    -- variables of the predicate, "T" and "F" the two branches.
    Acond p t e             -> do
      ident      <- mkNodeId atop
      vt         <- lift t
      ve         <- lift e
      PDoc p' vs <- ppE p
      let port = Just "P"
          doc  = mkTF $ Leaf (port, if simple detail then "?|" else p')
          deps = (vt, Just "T") : (ve, Just "F") : map (,port) vs
      return $ PNode ident doc deps

    Apply _ afun acc        -> apply <$> prettyDelayedAfun    detail     aenv afun
                                     <*> prettyDelayedOpenAcc detail ctx aenv acc

    -- The loop is printed as a single cell mentioning the labels of the
    -- predicate and body functions (drawn as separate graphs) and the initial
    -- value. 'replant' makes sure the initial value is a single port-less
    -- cell; check that explicitly so that a violation is reported here
    -- rather than as a deferred pattern-match failure.
    Awhile p f x            -> do
      ident <- mkNodeId atop
      x'    <- replant =<< prettyDelayedOpenAcc detail app aenv x
      p'    <- prettyDelayedAfun detail aenv p
      f'    <- prettyDelayedAfun detail aenv f
      case x' of
        PNode _ (Leaf (Nothing, xb)) fvs ->
          let loop = nest 2 (sep ["awhile", pretty p', pretty f', xb ])
          in  return $ PNode ident (Leaf (Nothing, loop)) fvs
        _ -> internalError "expected replanted leaf node"

    a@(Apair a1 a2)         -> mkNodeId a >>= prettyDelayedApair detail aenv a1 a2

    Anil                    -> "()"          .$ []
    Use repr arr            -> "use"         .$ [ return $ PDoc (prettyArray repr arr) [] ]
    Unit _ e                -> "unit"        .$ [ ppE e ]
    Generate _ sh f         -> "generate"    .$ [ ppE sh, ppF f ]
    Transform _ sh ix f xs  -> "transform"   .$ [ ppE sh, ppF ix, ppF f, ppA xs ]
    Reshape _ sh xs         -> "reshape"     .$ [ ppE sh, ppA xs ]
    Replicate _ty ix xs     -> "replicate"   .$ [ ppE ix, ppA xs ]
    Slice _ty xs ix         -> "slice"       .$ [ ppA xs, ppE ix ]
    Map _ f xs              -> "map"         .$ [ ppF f, ppA xs ]
    ZipWith _ f xs ys       -> "zipWith"     .$ [ ppF f, ppA xs, ppA ys ]
    Fold f (Just z) a       -> "fold"        .$ [ ppF f, ppE z, ppA a ]
    Fold f Nothing  a       -> "fold1"       .$ [ ppF f, ppA a ]
    FoldSeg _ f (Just z) a s -> "foldSeg"    .$ [ ppF f, ppE z, ppA a, ppA s ]
    FoldSeg _ f Nothing  a s -> "fold1Seg"   .$ [ ppF f, ppA a, ppA s ]
    Scan d f (Just z) a     -> ppD "scan" d ""  .$ [ ppF f, ppE z, ppA a ]
    Scan d f Nothing  a     -> ppD "scan" d "1" .$ [ ppF f, ppA a ]
    Scan' d f z a           -> ppD "scan" d "'" .$ [ ppF f, ppE z, ppA a ]
    Permute f dfts p xs     -> "permute"     .$ [ ppF f, ppA dfts, ppF p, ppA xs ]
    Backpermute _ sh p xs   -> "backpermute" .$ [ ppE sh, ppF p, ppA xs ]
    Stencil s _ sten bndy xs
                            -> "stencil"     .$ [ ppF sten, ppB (stencilEltR s) bndy, ppA xs ]
    Stencil2 s1 s2 _ sten bndy1 acc1 bndy2 acc2
                            -> "stencil2"    .$ [ ppF sten, ppB (stencilEltR s1) bndy1, ppA acc1, ppB (stencilEltR s2) bndy2, ppA acc2 ]
    Aforeign _ ff _afun xs  -> "aforeign"    .$ [ return (PDoc (pretty (strForeign ff)) []), ppA xs ]
  where
    -- Apply an operator to its arguments and wrap the result as the node
    -- for the term being drawn.
    (.$) :: Operator -> [Dot PDoc] -> Dot PNode
    name .$ docs = pnode =<< fmt name docs

    -- Run the argument printers in order and combine them into a single
    -- document. The dependencies of all arguments are always collected; the
    -- argument text itself is only shown at the 'Full' level of detail.
    fmt :: Operator -> [Dot PDoc] -> Dot PDoc
    fmt name docs = do
      docs' <- sequence docs
      let args = [ x | PDoc x _ <- docs' ]
          fvs  = [ x | PDoc _ x <- docs' ]
          doc  = if simple detail
                   then manifest name
                   else parensIf (needsParens ctx name)
                      $ nest shiftwidth
                      $ sep ( manifest name : args )
      return $ PDoc doc (concat fvs)

    -- A single-cell node for the current term, without ports
    pnode :: PDoc -> Dot PNode
    pnode (PDoc doc vs) = do
      let port = Nothing
      ident <- mkNodeId atop
      return $ PNode ident (Leaf (port, doc)) (map (,port) vs)

    -- Free array variables of closed scalar functions and expressions
    fvF :: Fun aenv t -> [Vertex]
    fvF = fvOpenFun Empty aenv

    fvE :: Exp aenv t -> [Vertex]
    fvE = fvOpenExp Empty aenv

    -- An array variable is shown by its label and depends on the node that
    -- defines it.
    avar :: ArrayVar aenv t -> PDoc
    avar (Var _ ix) = let (ident, v) = aprj ix aenv
                      in  PDoc (pretty v) [Vertex ident Nothing]

    aenv' :: Val aenv
    aenv' = avalToVal aenv

    -- Array arguments of an operator
    ppA :: HasCallStack => DelayedOpenAcc aenv a -> Dot PDoc
    ppA (Manifest (Avar ix))  = return (avar ix)
    ppA acc@Manifest{}        = do
      -- Any other manifest argument is drawn as a separate node under a
      -- fresh label, and referred to by that label.
      acc'  <- prettyDelayedOpenAcc detail app aenv acc
      v     <- mkLabel
      ident <- mkNode acc' (Just v)
      return $ PDoc (pretty v) [Vertex ident Nothing]
    ppA (Delayed _ sh f _)
      -- A delayed array which just indexes a variable over that variable's
      -- own shape is displayed as the variable itself.
      | Shape a    <- sh
      , Just b     <- isIdentityIndexing f
      , Just Refl  <- matchVar a b
      = ppA $ Manifest $ Avar a
    ppA (Delayed _ sh f _) = do
      PDoc d v <- "Delayed" `fmt` [ ppE sh, ppF f ]
      return    $ PDoc (parens d) v

    -- Stencil boundary conditions
    ppB :: forall sh e. HasCallStack
        => TypeR e
        -> Boundary aenv (Array sh e)
        -> Dot PDoc
    ppB _  Clamp        = return (PDoc "clamp"  [])
    ppB _  Mirror       = return (PDoc "mirror" [])
    ppB _  Wrap         = return (PDoc "wrap"   [])
    ppB tp (Constant e) = return (PDoc (prettyConst tp e) [])
    ppB _  (Function f) = ppF f

    -- Scalar functions (always parenthesised) and expressions, paired with
    -- the array variables they refer to
    ppF :: HasCallStack => Fun aenv t -> Dot PDoc
    ppF = return . uncurry PDoc . (parens . prettyFun aenv' &&& fvF)

    ppE :: HasCallStack => Exp aenv t -> Dot PDoc
    ppE = return . uncurry PDoc . (prettyExp aenv' &&& fvE)

    -- Operator name with a direction infix and a suffix, e.g. @scanl1@
    ppD :: String -> Direction -> String -> Operator
    ppD f LeftToRight k = fromString (f <> "l" <> k)
    ppD f RightToLeft k = fromString (f <> "r" <> k)

    -- The vertex of an array term: variables refer to their defining node,
    -- anything else is added to the graph as a new unlabelled node.
    lift :: HasCallStack => DelayedOpenAcc aenv a -> Dot Vertex
    lift Delayed{}                    = internalError "expected manifest array"
    lift (Manifest (Avar (Var _ ix))) = return $ Vertex (fst (aprj ix aenv)) Nothing
    lift acc                          = do
      acc'  <- prettyDelayedOpenAcc detail context0 aenv acc
      ident <- mkNode acc' Nothing
      return $ Vertex ident Nothing

    -- Prefix the node of the argument with the label of the applied function
    apply :: Label -> PNode -> PNode
    apply f (PNode ident x vs) =
      let x' = case x of
                 Leaf (p,d) -> Leaf (p, pretty f <+> d)
                 Forest ts  -> Forest (Leaf (Nothing,pretty f) : ts)
      in
      PNode ident x' vs
-- | Draw an array function as a separate graph and return the label under
-- which that graph was registered.
--
-- The label is @afun@ followed by one more than the number of graphs already
-- registered. Statements of the function's graph which are edges leaving a
-- node of the enclosing environment are not kept inside the function's graph
-- but moved to the front of the enclosing list of edges.
--
prettyDelayedAfun
    :: HasCallStack
    => Detail
    -> Aval aenv
    -> DelayedOpenAfun aenv afun
    -> Dot Label
prettyDelayedAfun detail aenv afun = do
  Graph _ stmts <- mkSubgraph (go aenv afun)
  count         <- gets (Seq.length . dotGraph)
  let label             = "afun" <> fromString (show (count + 1))
      (escaping, local) = partition fromOutside stmts
  --
  modify $ \s ->
    s { dotGraph = dotGraph s                           Seq.|> Graph label local
      , dotEdges = Seq.fromList [ e | E e <- escaping ] Seq.>< dotEdges s
      }
  return label
  where
    -- Nodes bound outside of the function being drawn
    outer = collect aenv

    -- Does this statement describe an edge originating at an outer node?
    fromOutside (E (Edge (Vertex ident _) _)) = Set.member ident outer
    fromOutside _                             = False

    -- Bind the parameters of the function, then draw its body
    go :: Aval aenv' -> DelayedOpenAfun aenv' a' -> Dot Graph
    go env fun =
      case fun of
        Abody b    -> graphDelayedOpenAcc detail env b
        Alam lhs f -> prettyLambdaALeftHandSide env lhs >>= \env' -> go env' f

    -- All node identifiers recorded in an environment
    collect :: Aval aenv' -> HashSet NodeId
    collect env =
      case env of
        Aempty         -> Set.empty
        Apush rest i _ -> Set.insert i (collect rest)
-- | Bind the variables of a let-bound left-hand side. Every variable is given
-- a fresh label and is associated with the supplied node, i.e. the node of
-- the bound computation. Returns the extended environment together with the
-- label describing the whole pattern: @()@ or @_@ for wildcards, and a
-- parenthesised, comma-separated pair for pairs.
--
prettyLetALeftHandSide
    :: forall repr aenv aenv'. HasCallStack
    => NodeId
    -> Aval aenv
    -> ALeftHandSide repr aenv aenv'
    -> Dot (Aval aenv', Label)
prettyLetALeftHandSide ident aenv lhs =
  case lhs of
    LeftHandSideWildcard repr ->
      let label :: Label
          label = case repr of
                    TupRunit -> "()"
                    _        -> "_"
      in  return (aenv, label)

    LeftHandSideSingle _ -> do
      name <- mkLabel
      return (Apush aenv ident name, name)

    LeftHandSidePair left right -> do
      (aenvL, labelL) <- prettyLetALeftHandSide ident aenv  left
      (aenvR, labelR) <- prettyLetALeftHandSide ident aenvL right
      return (aenvR, "(" <> labelL <> ", " <> labelR <> ")")
-- | Bind the parameters of an array function. Every variable receives a
-- fresh label and a node of its own, showing just that label and having no
-- incoming edges, which is added to the graph. Wildcards bind nothing.
-- Returns the extended environment.
--
prettyLambdaALeftHandSide
    :: forall repr aenv aenv'. HasCallStack
    => Aval aenv
    -> ALeftHandSide repr aenv aenv'
    -> Dot (Aval aenv')
prettyLambdaALeftHandSide aenv lhs =
  case lhs of
    LeftHandSideWildcard _ -> return aenv

    LeftHandSideSingle _ -> do
      name  <- mkLabel
      ident <- mkNodeId lhs
      _     <- mkNode (PNode ident (Leaf (Nothing, pretty name)) []) Nothing
      return (Apush aenv ident name)

    LeftHandSidePair left right ->
      prettyLambdaALeftHandSide aenv left >>= \aenvL ->
        prettyLambdaALeftHandSide aenvL right
-- | Draw a pair of array computations as one node with the given identifier.
--
-- Both components are first reduced to single port-less cells (see
-- 'replant') and shown together as a tuple. Edges already in the graph which
-- point at either component are retargeted at the node of the pair, and the
-- dependencies of the two components are concatenated.
--
prettyDelayedApair
    :: forall aenv a1 a2. HasCallStack
    => Detail
    -> Aval aenv
    -> DelayedOpenAcc aenv a1
    -> DelayedOpenAcc aenv a2
    -> NodeId
    -> Dot PNode
prettyDelayedApair detail aenv a1 a2 ident = do
  PNode id1 t1 v1 <- prettyElem a1
  PNode id2 t2 v2 <- prettyElem a2
  modify $ \s -> s { dotEdges = redirect [id1, id2] <$> dotEdges s }
  return $ PNode ident (forest [t1, t2]) (v1 ++ v2)
  where
    prettyElem :: DelayedOpenAcc aenv a -> Dot PNode
    prettyElem a = prettyDelayedOpenAcc detail context0 aenv a >>= replant

    -- Point an edge at the pair if it currently targets one of the components
    redirect :: [NodeId] -> Edge -> Edge
    redirect components edge@(Edge from (Vertex to port)) =
      if to `elem` components
        then Edge from (Vertex ident port)
        else edge

    -- Combine the port-less cells into a single tupled cell; cells of any
    -- other form are left out.
    forest :: [Tree (Maybe Port, Adoc)] -> Tree (Maybe Port, Adoc)
    forest trees = Leaf (Nothing, tupled (mapMaybe plain trees))

    plain :: Tree (Maybe Port, Adoc) -> Maybe Adoc
    plain (Leaf (Nothing, d)) = Just (align d)
    plain _                   = Nothing
-- | Ensure that a node consists of a single cell without a port.
--
-- A node already of that form is returned unchanged. Any other node is added
-- to the graph under a fresh label, and a stand-in node is returned in its
-- place which shows just that label and depends on the node that was added.
--
replant :: PNode -> Dot PNode
replant pnode@(PNode _     (Leaf (Nothing, _)) _) = return pnode
replant pnode@(PNode ident _                   _) = do
  vacuous <- mkNodeId pnode
  name    <- mkLabel
  _       <- mkNode pnode (Just name)
  let standIn = Leaf (Nothing, pretty name)
      source  = (Vertex ident Nothing, Nothing)
  return $ PNode vacuous standIn [source]
-- | The vertex of the node which defines the given array variable, as a
-- singleton list.
--
fvAvar :: Aval aenv -> ArrayVar aenv a -> [Vertex]
fvAvar env (Var _ ix) = [ Vertex node Nothing ]
  where
    node = fst (aprj ix env)
-- | Collect the vertices of the array variables referred to by a scalar
-- function. Parameters of the function extend the scalar environment before
-- the body is traversed.
--
fvOpenFun
    :: forall env aenv fun.
       Val env
    -> Aval aenv
    -> OpenFun env aenv fun
    -> [Vertex]
fvOpenFun env aenv fun =
  case fun of
    Body b    -> fvOpenExp env aenv b
    Lam lhs f -> fvOpenFun (fst (prettyELhs True env lhs)) aenv f
-- | Collect the vertices of the array variables referred to by a scalar
-- expression, in the order they are encountered and without removing
-- duplicates.
--
-- Array indexing always counts as a reference to the array; taking the shape
-- of an array only does so when 'cfgIncludeShape' is set.
--
fvOpenExp
    :: forall env aenv exp.
       Val env
    -> Aval aenv
    -> OpenExp env aenv exp
    -> [Vertex]
fvOpenExp env aenv = fv
  where
    fvF :: OpenFun env aenv f -> [Vertex]
    fvF = fvOpenFun env aenv

    fv :: OpenExp env aenv e -> [Vertex]
    -- Array accesses
    fv (Shape acc)
      | cfgIncludeShape         = fvAvar aenv acc
      | otherwise               = []
    fv (Index acc i)            = fvAvar aenv acc ++ fv i
    fv (LinearIndex acc i)      = fvAvar aenv acc ++ fv i
    -- The body of a let is traversed in the extended scalar environment
    fv (Let lhs e1 e2)          = fv e1 ++ fvOpenExp (fst (prettyELhs False env lhs)) aenv e2
    -- Everything else just collects from the sub-expressions
    fv Evar{}                   = []
    fv Undef{}                  = []
    fv Const{}                  = []
    fv PrimConst{}              = []
    fv (PrimApp _ x)            = fv x
    fv (Pair e1 e2)             = fv e1 ++ fv e2
    fv Nil                      = []
    fv (VecPack   _ e)          = fv e
    fv (VecUnpack _ e)          = fv e
    fv (IndexSlice _ slix sh)   = fv slix ++ fv sh
    fv (IndexFull  _ slix sh)   = fv slix ++ fv sh
    fv (ToIndex   _ sh ix)      = fv sh ++ fv ix
    fv (FromIndex _ sh ix)      = fv sh ++ fv ix
    fv (ShapeSize _ sh)         = fv sh
    fv Foreign{}                = []
    fv (Case e rhs def)         = fv e ++ concatMap (fv . snd) rhs ++ maybe [] fv def
    fv (Cond p t e)             = fv p ++ fv t ++ fv e
    fv (While p f x)            = fvF p ++ fvF f ++ fv x
    fv (Coerce _ _ e)           = fv e