{-# LANGUAGE NoMonomorphismRestriction #-}
module Math.Combinatorics.Digraph where
import Data.List as L
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Math.Core.Utils (picks, toSet)
-- | A directed graph, stored as @DG vs es@: a list of vertices @vs@ and a
-- list of directed edges @es@, where @(u,v)@ is an edge from @u@ to @v@.
-- Equality, ordering and 'show' are the structurally derived ones.
data Digraph v
    = DG [v] [(v,v)]
    deriving (Eq,Ord,Show)
instance Functor Digraph where
    -- Apply the function to every vertex and to both endpoints of every edge.
    -- The lists are mapped element-wise; nothing is re-sorted or de-duplicated.
    fmap f (DG vs es) = DG (fmap f vs) (fmap relabel es)
        where relabel (u,v) = (f u, f v)
-- | Normal form of a digraph: the same vertices and edges, with both lists
-- sorted into ascending order.
nf :: Ord v => Digraph v -> Digraph v
nf (DG verts arcs) = DG sortedVerts sortedArcs
    where sortedVerts = L.sort verts
          sortedArcs = L.sort arcs
-- | The vertex list of a digraph, exactly as stored.
vertices :: Digraph v -> [v]
vertices g = case g of
    DG verts _ -> verts
-- | The edge list of a digraph, exactly as stored.
edges :: Digraph v -> [(v,v)]
edges g = case g of
    DG _ arcs -> arcs
-- | The sources of all edges that end at the given vertex, in edge-list order
-- (one entry per matching edge).
predecessors :: Eq v => Digraph v -> v -> [v]
predecessors (DG _ arcs) v = map fst (filter endsAtV arcs)
    where endsAtV (_,target) = target == v
-- | The targets of all edges that start at the given vertex, in edge-list
-- order (one entry per matching edge).
successors :: Eq v => Digraph v -> v -> [v]
successors (DG _ arcs) u = map snd (filter startsAtU arcs)
    where startsAtU (source,_) = source == u
-- | Adjacency maps of a digraph, returned as @(preds, succs)@.
--
-- @preds@ maps each vertex with at least one incoming edge to the list of its
-- predecessors; @succs@ maps each vertex with at least one outgoing edge to
-- the list of its successors. Vertices without such edges have no key. Each
-- list keeps the order in which the edges occur in the edge list.
adjLists :: Ord v => Digraph v -> (M.Map v [v], M.Map v [v])
adjLists (DG _ arcs) = (predMap, succMap)
    where -- fromListWith calls the combiner as (new, old); flipping (++) puts the
          -- new neighbour after the ones already collected.
          predMap = M.fromListWith (flip (++)) [ (v,[u]) | (u,v) <- arcs ]
          succMap = M.fromListWith (flip (++)) [ (u,[v]) | (u,v) <- arcs ]
-- | All isomorphisms between two digraphs, found by naive backtracking.
--
-- Each result is a list of @(x,y)@ pairs sending every vertex @x@ of the
-- first digraph to a distinct vertex @y@ of the second. The most recently
-- assigned pair comes first. The result is @[]@ when the vertex counts or
-- the edge counts differ, or when no isomorphism exists.
--
-- A candidate pair @(x,y)@ is accepted only if
--
-- * @x@ has a self-loop exactly when @y@ does, and
--
-- * for every already-matched pair @(x',y')@, the edges @x->x'@ and @x'->x@
--   are present exactly when @y->y'@ and @y'->y@ are.
--
-- The self-loop test is needed because the pairwise test never compares a
-- vertex with itself. Without it, @DG [1,2] [(1,1),(1,2)]@ and
-- @DG [1,2] [(1,2),(2,2)]@ would be reported as isomorphic.
digraphIsos1 (DG vsa esa) (DG vsb esb)
    | length vsa /= length vsb = []
    | length esa /= length esb = []
    | otherwise = digraphIsos' [] vsa vsb
    where digraphIsos' xys [] [] = [xys]
          digraphIsos' xys (x:xs) ys =
              concat [ digraphIsos' ((x,y):xys) xs ys'
                     | (y,ys') <- picks ys, isCompatible (x,y) xys]
          -- Unreachable after the length guards above; kept so the helper is total.
          digraphIsos' _ _ _ = []
          isCompatible (x,y) xys =
              ((x,x) `elem` esa) == ((y,y) `elem` esb)
              && and [ ((x,x') `elem` esa) == ((y,y') `elem` esb)
                    && ((x',x) `elem` esa) == ((y',y) `elem` esb)
                     | (x',y') <- xys ]
-- | All isomorphisms between two digraphs, found by backtracking with
-- degree-based pruning.
--
-- Each result is a list of @(x,y)@ pairs, most recently assigned pair first.
-- The search is skipped (result @[]@) when the vertex counts differ or when
-- the sorted non-zero in-degree or out-degree sequences differ.
--
-- A candidate pair @(x,y)@ is accepted when @x@ and @y@ have equal in-degree
-- and equal out-degree, and every already-matched pair @(x',y')@ has @x'@ as
-- a predecessor or successor of @x@ exactly when @y'@ is one of @y@.
digraphIsos2 :: (Ord v, Ord w) => Digraph v -> Digraph w -> [[(v,w)]]
digraphIsos2 a b
    | sameOrder && sameInDegrees && sameOutDegrees = extend [] (vertices a) (vertices b)
    | otherwise = []
    where (predA,succA) = adjLists a
          (predB,succB) = adjLists b
          sameOrder = length (vertices a) == length (vertices b)
          sameInDegrees = degreeSeq predA == degreeSeq predB
          sameOutDegrees = degreeSeq succA == degreeSeq succB
          -- Sorted degrees of the vertices that appear as keys (degree >= 1).
          degreeSeq adj = L.sort (map length (M.elems adj))
          -- Neighbour list of a vertex; empty when the vertex has no key.
          neighbours v adj = M.findWithDefault [] v adj
          fits x y chosen =
              length px == length py
              && length sx == length sy
              && all consistent chosen
              where px = neighbours x predA
                    py = neighbours y predB
                    sx = neighbours x succA
                    sy = neighbours y succB
                    consistent (x',y') = (x' `elem` px) == (y' `elem` py)
                                      && (x' `elem` sx) == (y' `elem` sy)
          extend chosen [] [] = [chosen]
          extend chosen (x:xs) ys =
              [ iso
              | (y,rest) <- picks ys
              , fits x y chosen
              , iso <- extend ((x,y):chosen) xs rest ]
-- | Split the vertices of a digraph into successive levels.
--
-- The first level holds the vertices with no predecessors, in vertex-list
-- order. Each later level holds the successors of the previous level whose
-- predecessors have all appeared in earlier levels, passed through 'toSet'.
-- The process stops at the first empty level, so vertices that never become
-- ready are left out of the result.
heightPartitionDAG :: Ord v => Digraph v -> [[v]]
heightPartitionDAG dag@(DG vs _) = layers S.empty sources
    where (preds,succs) = adjLists dag
          sources = filter (`M.notMember` preds) vs
          layers _ [] = []
          layers seen level = level : layers seen' level'
              where seen' = S.union seen (S.fromList level)
                    level' = toSet (concatMap ready level)
                    -- Successors of u all of whose predecessors are now placed.
                    ready u = [ v
                              | v <- M.findWithDefault [] u succs
                              , all (`S.member` seen') (preds M.! v) ]
-- | True when 'heightPartitionDAG' places as many vertices as the digraph has
-- in its vertex list.
isDAG :: Ord v => Digraph v -> Bool
isDAG dag@(DG vs _) = placed == length vs
    where placed = sum (map length (heightPartitionDAG dag))
-- | All isomorphisms between two DAGs, matching vertices level by level.
--
-- Both digraphs are partitioned with 'heightPartitionDAG'. If a partition
-- does not cover every vertex, an error is raised (the first digraph is
-- checked first). If the level sizes differ the result is @[]@. Otherwise
-- vertices of each level of the first DAG are paired with vertices of the
-- same level of the second. A pair @(x,y)@ is kept only if every
-- already-matched @(x',y')@ has @x'@ as a predecessor of @x@ exactly when
-- @y'@ is a predecessor of @y@. Each result lists the most recently
-- assigned pair first.
dagIsos :: (Ord a, Ord b) => Digraph a -> Digraph b -> [[(a,b)]]
dagIsos dagA@(DG vsA _) dagB@(DG vsB _)
    | length vsA /= covered levelsA = error "dagIsos: dagA is not a DAG"
    | length vsB /= covered levelsB = error "dagIsos: dagB is not a DAG"
    | map length levelsA /= map length levelsB = []
    | otherwise = match [] levelsA levelsB
    where levelsA = heightPartitionDAG dagA
          levelsB = heightPartitionDAG dagB
          predsA = fst (adjLists dagA)
          predsB = fst (adjLists dagB)
          -- Number of vertices placed in some level.
          covered levels = sum (map length levels)
          match pairs [] [] = [pairs]
          match pairs ([]:restA) ([]:restB) = match pairs restA restB
          match pairs ((x:xs):restA) (ys:restB) =
              [ iso
              | (y,ys') <- picks ys
              , agrees x y pairs
              , iso <- match ((x,y):pairs) (xs:restA) (ys':restB) ]
          agrees x y pairs = all sameRelation pairs
              where px = M.findWithDefault [] x predsA
                    py = M.findWithDefault [] y predsB
                    sameRelation (x',y') = (x' `elem` px) == (y' `elem` py)
              
              
          
-- | True when 'dagIsos' finds at least one isomorphism between the two DAGs.
-- Inherits the errors 'dagIsos' raises for inputs that are not DAGs.
isDagIso :: (Ord a, Ord b) => Digraph a -> Digraph b -> Bool
isDagIso dagA dagB =
    case dagIsos dagA dagB of
        [] -> False
        _  -> True
-- | All permutations of a list.
--
-- Built from the right: the head element is inserted at every position
-- (left to right) of each permutation of the tail, so
-- @perms [1,2,3] == [[1,2,3],[2,1,3],[2,3,1],[1,3,2],[3,1,2],[3,2,1]]@.
perms :: [a] -> [[a]]
perms = foldr insertEverywhere [[]]
    where insertEverywhere x shorter =
              [ front ++ x : back
              | ps <- shorter
              , (front,back) <- zip (inits ps) (tails ps) ]
-- | Relabel a DAG with @1..n@ by brute force.
--
-- The levels of 'heightPartitionDAG' receive consecutive blocks of labels
-- starting at 1. Every permutation of every level is tried, and the result is
-- @DG [1..n] es'@, where @es'@ is the smallest sorted relabelled edge list
-- over all those labellings.
isoRepDAG1 :: Ord a => Digraph a -> Digraph Int
isoRepDAG1 dag@(DG vs es) = DG [1..length vs] (minimum relabelledEdgeLists)
    where relabelledEdgeLists =
              [ L.sort [ (m M.! u, m M.! v) | (u,v) <- es ] | m <- labellings ]
          labellings = extend [M.empty] 1 (heightPartitionDAG dag)
          -- extend partial next levels: combine every partial labelling with
          -- every ordering of the next level, numbered from `next` upward.
          extend partial _ [] = partial
          extend partial next (level:levels) =
              extend [ M.union old new | old <- partial, new <- additions ]
                     (next + length level)
                     levels
              where additions = [ M.fromList (zip ps [next..]) | ps <- perms level ]
-- | Enumerate every level-respecting labelling of a DAG and return the least.
--
-- The levels of 'heightPartitionDAG' are paired with consecutive blocks of
-- the labels @1, 2, ..@ of matching sizes. Every way of assigning a level's
-- labels to its vertices is generated. The result is the minimum (by list
-- ordering) of the resulting @(vertex, label)@ lists, each listing the most
-- recently assigned pair first.
--
-- NOTE(review): unlike the other @isoRepDAG*@ functions this returns an
-- assignment list, not a 'Digraph', and never consults the edges directly.
-- It looks unfinished; confirm before relying on it.
isoRepDAG2 :: (Ord a, Ord b, Num b, Enum b) => Digraph a -> [(a,b)]
isoRepDAG2 dag@(DG _ _) = minimum (assign [] srcLevels trgLevels)
    where srcLevels = heightPartitionDAG dag
          -- Cut 1,2,.. into consecutive blocks, one per source level.
          trgLevels = snd (L.mapAccumL takeBlock [1..] srcLevels)
          takeBlock labels level =
              let (block,remaining) = splitAt (length level) labels
              in (remaining, block)
          assign pairs [] [] = [pairs]
          assign pairs ([]:srcRest) ([]:trgRest) = assign pairs srcRest trgRest
          assign pairs ((x:xs):srcRest) (ys:trgRest) =
              [ done
              | (y,ys') <- picks ys
              , done <- assign ((x,y):pairs) (xs:srcRest) (ys':trgRest) ]
              
-- | Relabel a DAG with @1..n@ using a pruned depth-first search.
--
-- The levels of 'heightPartitionDAG' are paired with consecutive blocks of
-- the labels @1, 2, ..@. The search hands out labels in increasing order,
-- trying each remaining vertex of the current level for the next label, and
-- keeps a \"best\" node whose sorted edge list is compared against each
-- candidate. The result is @DG [1..n] es'@ where @es'@ is the edge list of
-- the best node when the stack runs out.
--
-- NOTE(review): presumably intended to return the same representative as
-- 'isoRepDAG1' with less work; that equivalence is not established by this
-- code alone - confirm with tests.
isoRepDAG3 dag@(DG vs es) = dfs root [root]
    where n = length vs
          -- A search node is (edges, (i,j), m, (srclevels, trglevels)):
          --   edges  - sorted relabelled edges built so far
          --   (i,j)  - i: value returned by nextunfinished; j: label assigned last
          --   m      - map from original vertex to its assigned label
          --   levels - source vertices still unlabelled / labels still unused
          root = ([],(1,0),M.empty,(srclevels,trglevels)) 
          (preds,succs) = adjLists dag
          srclevels = heightPartitionDAG dag
          -- Cut 1,2,.. into consecutive blocks, one block per source level.
          trglevels = reverse $ fst $ foldl
                      (\(tls,is) sl -> let (js,ks) = splitAt (length sl) is in (js:tls,ks))
                      ([],[1..]) srclevels
          -- Explicit-stack depth-first search carrying the best node so far.
          dfs best (node:stack) =
              -- cmpPartial's verdict is about `best` relative to `node`.
              case cmpPartial best node of
              LT -> dfs best stack                      -- best wins: drop node and its subtree
              GT -> dfs node (successors node ++ stack) -- node wins: it becomes best, then expand it
              EQ -> dfs best (successors node ++ stack) -- undecided: keep best, expand node
          -- Stack exhausted: the best node's edge list is the answer.
          dfs best@(es',_,_,_) [] = DG [1..n] es'
          -- Children of a node: give the next unused label y to each remaining
          -- vertex x of the current source level in turn.
          successors (es,_,_,([],[])) = []
          -- Current level fully labelled on both sides: move to the next level.
          successors (es,(i,j),m,([]:sls,[]:tls)) = successors (es,(i,j),m,(sls,tls))
          successors (es,(i,j),m,(xs:sls,(y:ys):tls)) =
              [ (es', (i',y), m', (L.delete x xs : sls, ys : tls))
              | x <- xs,
                let m' = M.insert x y m,
                -- One new edge (label of u, y) per predecessor u of x. The lookup
                -- uses m, the map before x was added; presumably every
                -- predecessor sits in an earlier level and is already labelled.
                let es' = L.sort $ es ++ [(m M.! u, y) | u <- M.findWithDefault [] x preds],
                let i' = nextunfinished m' i ]
          -- Starting from label i, skip labels whose vertex has all of its
          -- successors present in m; return the first label that either is not
          -- assigned yet or whose vertex still has an unlabelled successor.
          nextunfinished m i =
              case [v | (v,i') <- M.assocs m, i' == i] of
              [] -> i
              -- Only [] and [u] are handled: each label is handed out once by
              -- `successors`, so at most one vertex can carry label i.
              [u] -> if all (`M.member` m) (M.findWithDefault [] u succs)
                     then nextunfinished m (i+1) 
                     else i
          -- Compare best's edge list with the candidate's, using the
          -- candidate's (i',j') as the bound for deciding whether an LT is final.
          cmpPartial (es,_,_,_) (es',(i',j'),_,_) =
              cmpPartial' (i',j') es es'
              
          -- Walk both sorted edge lists in lockstep.
          cmpPartial' (i',j') ((u,v):es) ((u',v'):es') =
          
          
              case compare (u,v) (u',v') of
              EQ -> cmpPartial' (i',j') es es'
              -- best has the smaller edge: final (LT) only if that edge is <= (i',j');
              -- otherwise the verdict is left open (EQ).
              LT -> if (u,v) <= (i',j') then LT else EQ
              GT -> GT -- candidate has the smaller edge
                       
          -- Candidate ran out of edges first: same bound test as above.
          cmpPartial' (i',j') ((u,v):es) [] = if (u,v) <= (i',j') then LT else EQ
          cmpPartial' _ [] ((u',v'):es') = GT -- best ran out first: candidate replaces it
          cmpPartial' _ [] [] = EQ
-- | Relabel a DAG with @1..n@ using a pruned depth-first search over levels
-- refined by degree.
--
-- Each level of 'heightPartitionDAG' is split further into groups of vertices
-- sharing the same (in-degree, out-degree) pair. The groups are paired with
-- consecutive blocks of the labels @1, 2, ..@. The search hands out labels in
-- increasing order, trying each remaining vertex of the current group for the
-- next label, and keeps a \"best\" node whose sorted edge list is compared
-- against each candidate. The result is @DG [1..n] es'@ where @es'@ is the
-- edge list of the best node when the stack runs out.
--
-- NOTE(review): presumably isomorphic DAGs are meant to yield equal results;
-- that property is not established by this code alone - confirm with tests.
isoRepDAG :: (Ord a) => Digraph a -> Digraph Int
isoRepDAG dag@(DG vs es) = dfs root [root]
    where n = length vs
          -- A search node is (edges, (i,j), m, (srclevels, trglevels)):
          --   edges  - sorted relabelled edges built so far
          --   (i,j)  - i: value returned by nextunfinished; j: label assigned last
          --   m      - map from original vertex to its assigned label
          --   levels - source vertices still unlabelled / labels still unused
          root = ([],(1,0),M.empty,(srclevels,trglevels)) 
          (preds,succs) = adjLists dag
          indegs = M.map length preds
          outdegs = M.map length succs
          -- Sort a level by ((in-degree, out-degree), vertex) and group the
          -- vertices that share a degree pair; missing keys count as degree 0.
          byDegree vs = (map . map) snd $ L.groupBy (\(du,u) (dv,v) -> du == dv) $ L.sort
                        [( (M.findWithDefault 0 v indegs, M.findWithDefault 0 v outdegs), v) | v <- vs]
          -- Height levels, each replaced by its degree groups in sorted order.
          srclevels = concatMap byDegree $ heightPartitionDAG dag
          -- Cut 1,2,.. into consecutive blocks, one block per source group.
          trglevels = reverse $ fst $ foldl
                      (\(tls,is) sl -> let (js,ks) = splitAt (length sl) is in (js:tls,ks))
                      ([],[1..]) srclevels
          -- Explicit-stack depth-first search carrying the best node so far.
          dfs best (node:stack) =
              -- cmpPartial's verdict is about `best` relative to `node`.
              case cmpPartial best node of
              LT -> dfs best stack                      -- best wins: drop node and its subtree
              GT -> dfs node (successors node ++ stack) -- node wins: it becomes best, then expand it
              EQ -> dfs best (successors node ++ stack) -- undecided: keep best, expand node
          -- Stack exhausted: the best node's edge list is the answer.
          dfs best@(es',_,_,_) [] = DG [1..n] es'
          -- Children of a node: give the next unused label y to each remaining
          -- vertex x of the current source group in turn.
          successors (es,_,_,([],[])) = []
          -- Current group fully labelled on both sides: move to the next group.
          successors (es,(i,j),m,([]:sls,[]:tls)) = successors (es,(i,j),m,(sls,tls))
          successors (es,(i,j),m,(xs:sls,(y:ys):tls)) =
              [ (es', (i',y), m', (L.delete x xs : sls, ys : tls))
              | x <- xs,
                let m' = M.insert x y m,
                -- One new edge (label of u, y) per predecessor u of x. The lookup
                -- uses m, the map before x was added; presumably every
                -- predecessor sits in an earlier height level and is already
                -- labelled.
                let es' = L.sort $ es ++ [(m M.! u, y) | u <- M.findWithDefault [] x preds],
                let i' = nextunfinished m' i ]
          -- Starting from label i, skip labels whose vertex has all of its
          -- successors present in m; return the first label that either is not
          -- assigned yet or whose vertex still has an unlabelled successor.
          nextunfinished m i =
              case [v | (v,i') <- M.assocs m, i' == i] of
              [] -> i
              -- Only [] and [u] are handled: each label is handed out once by
              -- `successors`, so at most one vertex can carry label i.
              [u] -> if all (`M.member` m) (M.findWithDefault [] u succs)
                     then nextunfinished m (i+1) 
                     else i
          -- Compare best's edge list with the candidate's, using the
          -- candidate's (i',j') as the bound for deciding whether an LT is final.
          cmpPartial (es,_,_,_) (es',(i',j'),_,_) =
              cmpPartial' (i',j') es es'
              
          -- Walk both sorted edge lists in lockstep.
          cmpPartial' (i',j') ((u,v):es) ((u',v'):es') =
          
          
              case compare (u,v) (u',v') of
              EQ -> cmpPartial' (i',j') es es'
              -- best has the smaller edge: final (LT) only if that edge is <= (i',j');
              -- otherwise the verdict is left open (EQ).
              LT -> if (u,v) <= (i',j') then LT else EQ
              GT -> GT -- candidate has the smaller edge
                       
          -- Candidate ran out of edges first: same bound test as above.
          cmpPartial' (i',j') ((u,v):es) [] = if (u,v) <= (i',j') then LT else EQ
          cmpPartial' _ [] ((u',v'):es') = GT -- best ran out first: candidate replaces it
          cmpPartial' _ [] [] = EQ