{-# LANGUAGE OverloadedStrings #-}

module Calligraphy.Phases.Render.Mermaid
  ( renderMermaid,
  )
where

import Calligraphy.Phases.Render.Common
import Calligraphy.Prelude hiding (Decl, DeclType, Node)
import Calligraphy.Util.Printer
import Calligraphy.Util.Types
import Control.Monad.State (State, execState, modify)
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Tree (Tree (..))

-- | Print a 'RenderGraph' as a top-down Mermaid flowchart.
--
-- The output consists of, in order:
--
--   * the header line @flowchart TD@;
--   * the declaration hierarchy: one subgraph per module when the roots are
--     modules, otherwise the bare declaration trees;
--   * one solid edge (@-->@) per call pair, then one dotted edge (@-.->@) per
--     type pair, each set walked in ascending order, with parent-to-direct-child
--     pairs filtered out first (see @removeZeroEdges@ below);
--   * a default @classDef@ that makes node fills transparent.
--
-- Everything after the header line is emitted inside 'indent'.
--
-- NOTE(review): labels are written unquoted, so a label containing Mermaid
-- syntax characters (brackets, parentheses, quotes, @|@) could break the
-- diagram; confirm what labels can contain before relying on this.
renderMermaid :: Prints RenderGraph
renderMermaid (RenderGraph roots calls types) = do
  textLn "flowchart TD"
  indent $ do
    printRoots
    mapM_ (printEdge " --> ") (removeZeroEdges calls)
    mapM_ (printEdge " -.-> ") (removeZeroEdges types)
    strLn "classDef default fill-opacity:0,stroke:#777;"
  where
    -- Hierarchy section: plain declaration trees, or one subgraph per module.
    printRoots = case roots of
      Right trees -> mapM_ printTree trees
      Left modules -> mapM_ printModule modules

    -- A single edge line: source id, the arrow token, target id.
    printEdge :: String -> Prints (ID, ID)
    printEdge arrow (src, dst) = strLn (src <> arrow <> dst)

    -- A childless declaration becomes one node whose bracket shape depends on
    -- its declaration type. A declaration with children becomes a subgraph
    -- (titled with its label) that contains the children. Label lines are
    -- joined with a literal backslash-n, and non-exported declarations get a
    -- dashed stroke.
    printTree :: Prints (Tree RenderNode)
    printTree (Node (RenderNode nodeid typ lbll export) children) =
      case children of
        [] -> do
          strLn (nodeid <> nodeShape typ caption)
          styleUnexported
        _ ->
          brack ("subgraph " <> nodeid <> "[" <> caption <> "]") "end" $ do
            styleUnexported
            mapM_ printTree children
      where
        caption = intercalate "\\n" lbll
        styleUnexported =
          unless export $
            strLn ("style " <> nodeid <> " stroke-dasharray: 5 5")

    -- A module is a subgraph, titled with the module label, that wraps all of
    -- its top-level declaration trees.
    printModule :: Prints RenderModule
    printModule (RenderModule lbl modId decls) =
      brack header "end" (mapM_ printTree decls)
      where
        header = "subgraph " <> modId <> " [" <> lbl <> "]"

    -- Because hierarchies are rendered using subgraphs, there's an odd edge
    -- case when there's an edge between a parent and its child: Mermaid
    -- renders it as a zero-length edge, i.e. _just_ an arrowhead. So every
    -- (parent, direct child) pair in the declaration forest is removed from
    -- the edge set.
    removeZeroEdges :: Set (ID, ID) -> Set (ID, ID)
    removeZeroEdges = execState (mapM_ pruneUnder declRoots)
      where
        -- All declaration trees, flattened across modules if necessary.
        declRoots :: NonEmpty (Tree RenderNode)
        declRoots = case roots of
          Right trees -> trees
          Left mods -> mods >>= moduleDecls

        -- Drop the edge from this node to each of its children, recursively.
        pruneUnder :: Tree RenderNode -> State (Set (ID, ID)) ()
        pruneUnder (Node (RenderNode parentId _ _ _) children) =
          mapM_ (dropLink parentId) children

        -- Remove the (parent, child) edge, then continue into the child.
        dropLink :: ID -> Tree RenderNode -> State (Set (ID, ID)) ()
        dropLink parentId child@(Node (RenderNode childId _ _ _) _) = do
          modify (Set.delete (parentId, childId))
          pruneUnder child

-- | Wrap a node label in the Mermaid shape delimiters chosen for its
-- declaration type:
--
--   * data declarations: stadium, @([label])@;
--   * value declarations: rectangle, @[label]@;
--   * records and constructors: rounded rectangle, @(label)@;
--   * classes: trapezoid, @[/label\\]@.
nodeShape :: DeclType -> String -> String
nodeShape declType = case declType of
  DataDecl -> wrapBracket "([" "])"
  ValueDecl -> wrapBracket "[" "]"
  RecDecl -> wrapBracket "(" ")"
  ConDecl -> wrapBracket "(" ")"
  ClassDecl -> wrapBracket "[/" "\\]"

-- | @wrapBracket open close inner@ returns @inner@ with @open@ prepended and
-- @close@ appended, e.g. @wrapBracket "([" "])" "x" == "([x])"@.
wrapBracket :: String -> String -> String -> String
wrapBracket open close = (open <>) . (<> close)