-----------------------------------------------------------------------------

-- Copyright 2018, Ideas project team. This file is distributed under the

-- terms of the Apache License 2.0. For more information, see the files

-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.

-----------------------------------------------------------------------------

-- |

-- Maintainer  :  bastiaan.heeren@ou.nl

-- Stability   :  provisional

-- Portability :  portable (depends on ghc)

--

-----------------------------------------------------------------------------



module Ideas.Common.Strategy.CyclicTree

   ( -- * Data type

     CyclicTree

     -- * Constructor functions

   , node, node0, node1, node2, leaf, label

     -- * Querying

   , isNode, isLeaf, isLabel

     -- * Replace functions

   , replaceNode, replaceLeaf, replaceLabel, shrinkTree

     -- * Fold and algebra

   , fold, foldUnwind

   , CyclicTreeAlg, fNode, fLeaf, fLabel, fRec, fVar

   , emptyAlg, monoidAlg

   ) where



import Control.Monad

import Data.List (intercalate)

import Ideas.Common.Classes

import Ideas.Common.Id

import Test.QuickCheck hiding (label)



--------------------------------------------------------------

-- Data type



-- | A tree with values of type @a@ at the interior nodes and values of type
-- @b@ at the leaves. Cycles are not stored as real pointers; they are
-- written down with numbered binders ('Rec') and numbered references ('Var').
data CyclicTree a b

   = Node a [CyclicTree a b]   -- ^ interior node: a value and its subtrees

   | Leaf b                    -- ^ leaf holding a value

   | Label Id (CyclicTree a b) -- ^ subtree tagged with an identifier

   | Rec Int (CyclicTree a b)  -- ^ binder for the variable with this number

   | Var Int                   -- ^ reference to a variable by its number



-- | Textual rendering: a node is its value followed by its children in
-- parentheses, a label is written as @id:tree@, a binder as @#n=tree@ and a
-- variable reference as @#n@.
instance (Show a, Show b) => Show (CyclicTree a b) where
   show = fold Alg
      { fNode  = showNode
      , fLeaf  = show
      , fLabel = showLabel
      , fRec   = showRec
      , fVar   = showVar
      }
    where
      showNode a children = show a ++ par children
      showLabel l body    = show l ++ ":" ++ body
      showRec n body      = showVar n ++ "=" ++ body
      showVar n           = '#' : show n



-- | Maps one function over all node values and another over all leaf
-- values; labels, binders and variable references are rebuilt unchanged.
instance BiFunctor CyclicTree where
   biMap f g = fold Alg
      { fNode  = \a -> Node (f a)
      , fLeaf  = \b -> Leaf (g b)
      , fLabel = Label
      , fRec   = Rec
      , fVar   = Var
      }



-- | Functor over the leaf values, defined through the 'BiFunctor' instance.
instance Functor (CyclicTree d) where
   fmap g tree = mapSecond g tree



-- | 'pure' wraps a value in a single leaf. '<*>' replaces every leaf of the
-- left tree (which holds a function) by the right tree with that function
-- mapped over its leaves.
instance Applicative (CyclicTree d) where
   pure b        = Leaf b
   funs <*> args = replaceLeaf (\f -> fmap f args) funs



-- | Bind is substitution at the leaves: every leaf is replaced by the tree
-- that the continuation computes from the leaf's value.
instance Monad (CyclicTree d) where
   return     = pure
   tree >>= k = replaceLeaf k tree



-- | Folds over the leaf values only: children of a node are combined with
-- 'mconcat', labels and binders pass their body through, and variable
-- references contribute 'mempty'.
instance Foldable (CyclicTree d) where
   foldMap f = fold Alg
      { fNode  = \_ parts -> mconcat parts
      , fLeaf  = f
      , fLabel = \_ body -> body
      , fRec   = \_ body -> body
      , fVar   = \_ -> mempty
      }



-- | Applies an effectful function to every leaf value and rebuilds the tree
-- with the same shape. The children of a node are sequenced in list order.
instance Traversable (CyclicTree d) where
   traverse f = fold Alg
      { fNode  = \a kids -> node a <$> sequenceA kids
      , fLeaf  = \b -> leaf <$> f b
      , fLabel = \l body -> label l <$> body
      , fRec   = \n body -> Rec n <$> body
      , fVar   = \n -> pure (Var n)
      }



-- | Fixed point with an explicit binder. The function is first applied to a
-- placeholder variable (numbered -1) so that 'vars' can report which numbers
-- the result uses. The binder then gets the number one above the largest
-- reported number, or 0 when no non-negative number is reported.
instance Fix (CyclicTree a b) where
   fix f = Rec fresh (f (Var fresh))
    where
      probe = f (Var (-1))
      fresh = foldr max (-1) (vars probe) + 1



-- | Random trees come from 'arbTree' (driven by QuickCheck's size
-- parameter); shrinking is delegated to 'shrinkTree'.
instance (Arbitrary a, Arbitrary b) => Arbitrary (CyclicTree a b) where
   arbitrary   = sized (\size -> arbTree size)
   shrink tree = shrinkTree tree



-- | Generator for random trees, steered by a size argument.
--
-- At size 0 the result is a leaf (weight 3) or, when at least one binder is
-- in scope, a reference to one of the bound variables (weight 1). At other
-- sizes it is a node with 0 to 3 children (weight 3), a size-0 tree
-- (weight 2), a subtree labelled with a single upper-case letter (weight 1),
-- or a new binder (weight 1). Subtrees are generated at half the size.
arbTree :: (Arbitrary a, Arbitrary b) => Int -> Gen (CyclicTree a b)
arbTree = gen 0
 where
   -- bound counts the enclosing binders, which are numbered 1..bound
   gen bound size
      | size == 0 = frequency (leafCase : varCase)
      | otherwise = frequency
           [ (3, node <$> arbitrary <*> children)
           , (2, gen bound 0)
           , (1, label <$> letterId <*> child)
           , (1, Rec inner <$> gen inner half)
           ]
    where
      leafCase = (3, leaf <$> arbitrary)
      -- only offered when there is a variable to refer to
      varCase  = [ (1, elements [ Var i | i <- [1 .. bound] ]) | bound > 0 ]
      half     = size `div` 2
      inner    = bound + 1
      child    = gen bound half
      letterId = elements [ newId [c] | c <- ['A' .. 'Z'] ]
      children = do
         count <- choose (0, 3)
         replicateM count child



-- | Candidate smaller trees for QuickCheck shrinking. A node shrinks to each
-- of its children and to the same node with one child shrunk; a label
-- shrinks to its body and to the label around a shrunk body; a binder only
-- shrinks its body. Leaves and variable references have no shrinks.
shrinkTree :: CyclicTree a b -> [CyclicTree a b]
shrinkTree (Node a ts) = ts ++ [ node a ts' | ts' <- shrinkTrees ts ]
shrinkTree (Label l t) = t : [ Label l t' | t' <- shrinkTree t ]
shrinkTree (Rec n t)   = [ Rec n t' | t' <- shrinkTree t ]
shrinkTree _           = []



-- | All ways to shrink exactly one tree in the list while leaving the other
-- trees as they are. Shrinks of earlier list elements come first.
shrinkTrees :: [CyclicTree a b] -> [[CyclicTree a b]]
shrinkTrees trees =
   case trees of
      []     -> []
      t : ts -> [ t' : ts | t' <- shrinkTree t ]
             ++ [ t : ts' | ts' <- shrinkTrees ts ]



-- local helpers

-- | Renders a list of strings as a parenthesised, comma-separated argument
-- list. An empty list gives the empty string (no parentheses at all).
par :: [String] -> String
par [] = ""
par xs = "(" ++ intercalate ", " xs ++ ")"



-- | Collects every variable number that occurs in a tree: the numbers
-- introduced by recursion binders as well as the numbers of variable
-- references. Binder numbers are included so that a caller looking for an
-- unused number (such as 'fix') never picks one that an existing binder
-- inside the tree would capture.
vars :: CyclicTree a b -> [Int]
vars = fold monoidAlg {fRec = (:), fVar = return}



--------------------------------------------------------------

-- Constructor functions



-- | Builds an interior node from a value and a list of subtrees.
node :: a -> [CyclicTree a b] -> CyclicTree a b
node a children = Node a children



-- | Builds a node without any subtrees.
node0 :: a -> CyclicTree a b
node0 a = Node a []



-- | Builds a node with exactly one subtree.
node1 :: a -> CyclicTree a b -> CyclicTree a b
node1 a child = Node a [child]



-- | Builds a node with exactly two subtrees, in the given order.
node2 :: a -> CyclicTree a b -> CyclicTree a b -> CyclicTree a b
node2 a left right = Node a [left, right]



-- | Builds a leaf holding the given value.
leaf :: b -> CyclicTree a b
leaf b = Leaf b



-- | Tags a tree with an identifier; the first argument is converted to an
-- 'Id' with 'newId'.
label :: IsId n => n -> CyclicTree a b -> CyclicTree a b
label n tree = Label (newId n) tree



--------------------------------------------------------------

-- Querying



-- | Returns the value and subtrees when the root of the tree is a node, and
-- 'Nothing' for every other constructor.
isNode :: CyclicTree a b -> Maybe (a, [CyclicTree a b])
isNode tree =
   case tree of
      Node a xs -> Just (a, xs)
      _         -> Nothing



-- | Returns the value when the tree is a single leaf, and 'Nothing' for
-- every other constructor.
isLeaf :: CyclicTree a b -> Maybe b
isLeaf tree =
   case tree of
      Leaf b -> Just b
      _      -> Nothing



-- | Returns the identifier and the labelled subtree when the root of the
-- tree is a label, and 'Nothing' for every other constructor.
isLabel :: CyclicTree a b -> Maybe (Id, CyclicTree a b)
isLabel tree =
   case tree of
      Label l t -> Just (l, t)
      _         -> Nothing



--------------------------------------------------------------

-- Replace functions



-- | Rebuilds a tree bottom-up, using the given function instead of the node
-- constructor. The function receives the node value and the already
-- rewritten subtrees; all other constructors are kept.
replaceNode :: (a -> [CyclicTree a b] -> CyclicTree a b) -> CyclicTree a b -> CyclicTree a b
replaceNode f = fold Alg
   { fNode  = f
   , fLeaf  = Leaf
   , fLabel = Label
   , fRec   = Rec
   , fVar   = Var
   }



-- | Rebuilds a tree bottom-up, using the given function instead of the
-- label constructor. The function receives the identifier and the already
-- rewritten body; all other constructors are kept.
replaceLabel :: (Id -> CyclicTree a b -> CyclicTree a b) -> CyclicTree a b -> CyclicTree a b
replaceLabel f = fold Alg
   { fNode  = Node
   , fLeaf  = Leaf
   , fLabel = f
   , fRec   = Rec
   , fVar   = Var
   }



-- | Substitutes every leaf by the tree that the given function computes
-- from the leaf's value; all other constructors are kept. The leaf type of
-- the result may differ from that of the input.
replaceLeaf :: (b -> CyclicTree a c) -> CyclicTree a b -> CyclicTree a c
replaceLeaf f = fold Alg
   { fNode  = Node
   , fLeaf  = f
   , fLabel = Label
   , fRec   = Rec
   , fVar   = Var
   }



--------------------------------------------------------------

-- Fold and algebra



-- | Structural fold: every constructor is replaced by the corresponding
-- function of the algebra, after the subtrees have been folded. Binders and
-- variable references are handed to the algebra as they are; recursion is
-- not unrolled (see 'foldUnwind' for that).
fold :: CyclicTreeAlg a b t -> CyclicTree a b -> t
fold alg = go
 where
   go tree =
      case tree of
         Node a ts -> fNode alg a (map go ts)
         Leaf b    -> fLeaf alg b
         Label l t -> fLabel alg l (go t)
         Rec n t   -> fRec alg n (go t)
         Var n     -> fVar alg n



-- | Fold that unrolls recursion instead of reporting it. Only the node,
-- leaf and label functions of the algebra are used. The result of a binder
-- is the result of its body, and a variable reference evaluates to the
-- result of the binder with the same number (found through an environment
-- that is threaded down the tree). A reference without an enclosing binder
-- raises the error @foldUnwind: unbound var@ when its result is demanded.
foldUnwind :: CyclicTreeAlg a b t -> CyclicTree a b -> t
foldUnwind alg tree = fold unwindAlg tree unbound
 where
   -- every subtree is folded to a function from environment to result
   unwindAlg = Alg
      { fNode  = \a kids env -> fNode alg a [ kid env | kid <- kids ]
      , fLeaf  = \b _ -> fLeaf alg b
      , fLabel = \l body env -> fLabel alg l (body env)
        -- lazily tie the knot: the body sees its own result under number n
      , fRec   = \n body env ->
           let result = body (bind n result env)
           in result
      , fVar   = \n env -> env n
      }

   -- initial environment: nothing is bound
   unbound = error "foldUnwind: unbound var"

   -- environment extended with one binding, which shadows earlier ones
   bind n value env i = if i == n then value else env i



-- | An algebra for 'CyclicTree': one function per constructor. When used
-- with 'fold', each function receives the fields of its constructor with
-- the subtrees already replaced by their results of type @t@.
data CyclicTreeAlg a b t = Alg

   { fNode  :: a -> [t] -> t   -- ^ node value and the results of its subtrees

   , fLeaf  :: b -> t          -- ^ leaf value

   , fLabel :: Id -> t -> t    -- ^ label identifier and the result of its body

   , fRec   :: Int -> t -> t   -- ^ binder number and the result of its body

   , fVar   :: Int -> t        -- ^ number of a variable reference

   }



-- | The algebra that rebuilds a tree with its own constructors; folding
-- with it returns a copy of the input. Meant to be customised with record
-- update syntax.
idAlg :: CyclicTreeAlg a b (CyclicTree a b)
idAlg = Alg
   { fNode  = Node
   , fLeaf  = Leaf
   , fLabel = Label
   , fRec   = Rec
   , fVar   = Var
   }



-- | An algebra in which every function is an error. It is a starting point
-- for record update: any field that is not overridden raises
-- @emptyAlg: uninitialized@ when it is used.
emptyAlg :: CyclicTreeAlg a b t
emptyAlg = Alg
   { fNode  = missing
   , fLeaf  = missing
   , fLabel = missing
   , fRec   = missing
   , fVar   = missing
   }
 where
   missing = error "emptyAlg: uninitialized"



-- | An algebra with a monoid as result type: the results of a node's
-- subtrees are combined with 'mconcat', labels and binders return the
-- result of their body, and leaves and variable references yield 'mempty'.
-- Override individual fields to collect information from a tree.
monoidAlg :: Monoid m => CyclicTreeAlg a b m
monoidAlg = Alg
   { fNode  = \_ parts -> mconcat parts
   , fLeaf  = \_ -> mempty
   , fLabel = \_ body -> body
   , fRec   = \_ body -> body
   , fVar   = \_ -> mempty
   }