{-# LANGUAGE DeriveGeneric           #-}
{-# LANGUAGE QuantifiedConstraints   #-}
{-# LANGUAGE UndecidableInstances    #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE ConstraintKinds         #-}
{-# LANGUAGE DefaultSignatures       #-}
{-# LANGUAGE StandaloneDeriving      #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE DataKinds               #-}
{-# LANGUAGE PolyKinds               #-}
{-# LANGUAGE GADTs                   #-}
{-# LANGUAGE KindSignatures          #-}
{-# LANGUAGE PatternSynonyms         #-}
{-# LANGUAGE RankNTypes              #-}
{-# LANGUAGE TypeFamilies            #-}
{-# LANGUAGE MultiParamTypeClasses   #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# OPTIONS_GHC -Wno-orphans         #-}
-- |Deep representation for 'SRep'
module Generics.Simplistic.Deep
  ( -- * (Co)Free (Co)Monad and its cousins
    HolesAnn(..)
  , SFix    , pattern SFix    , pattern Prim
  , SFixAnn , pattern SFixAnn , pattern PrimAnn
  , Holes   , pattern Roll    , pattern Hole
  -- ** Constraints
  , CompoundCnstr , PrimCnstr
  -- ** Coercions
  , holesToSFix , sfixToHoles
  -- ** Maps, zips and folds
  , holesMapAnn , holesMap , holesMapM , holesMapAnnM , getAnn
  , holesJoin , holesSize, holesHolesList
  , holesRefineM , holesRefineHoles , holesRefineHolesM
  , synthesize , synthesizeM , cataM
  -- ** Anti-Unification
  , lgg
  -- ** Conversion
  , Deep(..) , GDeep(..)
  ) where

import Data.Proxy
-- import qualified Data.Set as S (Set, fromList)
import Control.Monad.Identity
import Control.DeepSeq
import GHC.Generics (from , to)
import Unsafe.Coerce

import Generics.Simplistic
import Generics.Simplistic.Util

-- Useful constraints

-- |Constraint satisfied by a primitive type @b@: it is an element
-- of @kappa@ and is not an element of @fam@. This is the context
-- carried by the 'Prim'' constructor of 'HolesAnn'.
type PrimCnstr kappa fam b
  = (Elem b kappa , NotElem b fam)

-- |Constraint satisfied by a compound type @a@: it is an element
-- of @fam@, is not an element of @kappa@ and has a 'Generic'
-- instance. This is the context carried by the 'Roll'' constructor
-- of 'HolesAnn'.
type CompoundCnstr kappa fam a
  = (Elem a fam , NotElem a kappa , Generic a)

-- |The cofree comonad and free monad on the same type;
-- this allows us to use the same recursion operator
-- for everything.
--
-- Every constructor stores an annotation of type @ann a@; the
-- constructors differ in what they store next to it.
data HolesAnn kappa fam ann h a where
  -- |A hole, carrying a value of type @h a@.
  Hole' :: ann a -- ^ Annotation
        -> h a -> HolesAnn kappa fam ann h a
  -- |A primitive value of type @a@; requires @'PrimCnstr' kappa fam a@.
  Prim' :: (PrimCnstr kappa fam a)
        => ann a -- ^ Annotation
        -> a -> HolesAnn kappa fam ann h a
  -- |One layer of the generic representation of @a@, parameterised
  -- by 'HolesAnn' itself; requires @'CompoundCnstr' kappa fam a@.
  Roll' :: (CompoundCnstr kappa fam a)
        => ann a -- ^ Annotation
        -> SRep (HolesAnn kappa fam ann h) (Rep a)
        -> HolesAnn kappa fam ann h a

-- |Two trees with holes are equal when every hole of their least
-- general generalization ('lgg') pairs up two 'Hole's whose contents
-- agree under 'eqHO'. Any other pair found in a hole of the 'lgg'
-- makes the trees different.
instance (All Eq kappa , EqHO h) => EqHO (Holes kappa fam h) where
  eqHO l r = all (exElim sameHole) (holesHolesList (lgg l r))
    where
      -- Inspects one pair of disagreeing subtrees reported by 'lgg'.
      sameHole :: (Holes kappa fam h :*: Holes kappa fam h) a -> Bool
      sameHole (Hole hl :*: Hole hr) = eqHO hl hr
      sameHole _                     = False

-- |Plain equality on trees with holes is delegated to 'eqHO'.
instance (All Eq kappa , EqHO h) => Eq (Holes kappa fam h t) where
   l == r = eqHO l r

-- |Deep representations are easily achieved by forbidding
-- the 'Hole'' constructor (the hole type is 'V1') and providing
-- unit annotations (the annotation type is 'U1').
type SFix kappa fam = HolesAnn kappa fam U1 V1

-- |Matches or builds a 'Roll'' carrying the unit annotation 'U1'
-- inside a 'SFix'; it is 'Roll' specialised to the 'SFix' type.
pattern SFix :: () => (CompoundCnstr kappa fam a)
             => SRep (SFix kappa fam) (Rep a)
             -> SFix kappa fam a
pattern SFix x = Roll' U1 x
{-# COMPLETE SFix , Prim #-}

-- |A tree with holes has unit annotations: it is a 'HolesAnn'
-- whose annotation type is fixed to 'U1', leaving the hole type free.
type Holes kappa fam = HolesAnn kappa fam U1

-- |A 'Hole'' carrying the unit annotation 'U1'.
pattern Hole :: h a -> Holes kappa fam h a
pattern Hole x = Hole' U1 x

-- |A 'Prim'' carrying the unit annotation 'U1'. Matching on it
-- provides @'PrimCnstr' kappa fam a@.
pattern Prim :: () => (PrimCnstr kappa fam a)
             => a -> Holes kappa fam h a
pattern Prim a = Prim' U1 a

-- |A 'Roll'' carrying the unit annotation 'U1'. Matching on it
-- provides @'CompoundCnstr' kappa fam a@.
--
-- The @COMPLETE@ pragma below declares 'Hole', 'Prim' and 'Roll' to
-- be an exhaustive set of patterns.
pattern Roll :: () => (CompoundCnstr kappa fam a)
             => SRep (Holes kappa fam h) (Rep a)
             -> Holes kappa fam h a
pattern Roll x = Roll' U1 x
{-# COMPLETE Hole , Prim , Roll #-}

-- |Annotated fixpoints are also easy; forbid the 'Hole''
-- constructor (the hole type is 'V1') but add something to
-- every 'Roll'' and 'Prim'' of the representation, namely an
-- annotation of type @ann@.
type SFixAnn kappa fam ann = HolesAnn kappa fam ann V1

-- |An annotated primitive: 'Prim'' specialised to the 'SFixAnn'
-- type. Matching on it provides @'PrimCnstr' kappa fam a@.
pattern PrimAnn :: () => (PrimCnstr kappa fam a)
                => ann a -> a -> SFixAnn kappa fam ann a
pattern PrimAnn ann a = Prim' ann a

-- |An annotated layer of representation: 'Roll'' specialised to the
-- 'SFixAnn' type. Matching on it provides
-- @'CompoundCnstr' kappa fam a@.
--
-- The @COMPLETE@ pragma below declares 'SFixAnn' and 'PrimAnn' to be
-- an exhaustive set of patterns.
pattern SFixAnn :: () => (CompoundCnstr kappa fam a)
                => ann a
                -> SRep (SFixAnn kappa fam ann) (Rep a)
                -> SFixAnn kappa fam ann a
pattern SFixAnn ann x = Roll' ann x
{-# COMPLETE SFixAnn , PrimAnn #-}

---------------
-- Coercions --
---------------

-- |Reinterprets a 'SFix' as a tree with holes of an arbitrary hole
-- type @h@, without traversing it.
--
-- Safety of the 'unsafeCoerce': 'SFix' and 'Holes' are the same
-- 'HolesAnn' type up to the hole parameter, and 'Hole'' is the only
-- constructor that stores a value of the hole type. In a 'SFix' that
-- parameter is 'V1', so (assuming 'V1' is the constructor-less type
-- from "GHC.Generics") a fully defined 'SFix' contains no 'Hole''
-- and changing the hole type cannot expose an ill-typed value.
sfixToHoles :: SFix kappa fam at -> Holes kappa fam h at
sfixToHoles tree = unsafeCoerce tree

-- |A tree whose holes have type 'V1' already is a 'SFix': both
-- sides of the arrow are the same 'HolesAnn' type once the synonyms
-- are unfolded, so the tree is returned unchanged.
holesToSFix :: Holes kappa fam V1 at -> SFix kappa fam at
holesToSFix tree = tree

------------
-- NFData --
------------

-- VCM: QUESTION: Does it make sense to have this here?
-- I need it in /hdiff/, and I can see how it can be useful.
-- @trupill, do you prefer to keep this or trash this?

-- |Forces a 'HolesAnn' by forcing every annotation, every hole and
-- every layer of generic representation.
--
-- The instance context carries no 'NFData' evidence for the primitive
-- values stored under 'Prim'', so those cannot be reduced to normal
-- form here. They are evaluated to weak head normal form instead,
-- so that 'rnf' at least does not leave the payload as an
-- unevaluated thunk.
instance (forall x . NFData (ann x) , forall x . NFData (h x))
    => NFData (HolesAnn kappa fam ann h f) where
  -- No 'NFData' instance is available for @x@; WHNF is the best we can do.
  rnf (Prim' ann x) = rnf ann `seq` x `seq` ()
  rnf (Hole' ann h) = rnf ann `seq` rnf h
  rnf (Roll' ann x) = rnf ann `seq` rnf x

-- |Orphan instance: 'rnf' on 'V1' ignores its argument and yields @()@.
instance NFData (V1 x) where
  rnf = const ()

-- |Orphan instance: 'rnf' on 'U1' matches the single constructor and
-- yields @()@.
instance NFData (U1 x) where
  rnf u = case u of
    U1 -> ()

----------------------
-- Useful Functions --
----------------------

-- |Retrieves the topmost annotation of a 'HolesAnn'; every
-- constructor stores one. This is the counit of the comonad.
getAnn :: HolesAnn kappa fam ann h a -> ann a
getAnn tree = case tree of
  Hole' ann _ -> ann
  Prim' ann _ -> ann
  Roll' ann _ -> ann

-- TODO: swap parameters
-- |Maps over a 'HolesAnn' treating annotations and holes
-- independently. At every node the annotation action runs first;
-- on a 'Hole'' it is followed by the hole action, and on a 'Roll''
-- by the recursive traversal of the children through 'repMapM'.
holesMapAnnM :: (Monad m)
             => (forall x . f x   -> m (g x)) -- ^ Function to transform holes
             -> (forall x . ann x -> m (psi x)) -- ^ Function to transform annotations
             -> HolesAnn kappa fam ann f a -> m (HolesAnn kappa fam psi g a)
holesMapAnnM onHole onAnn tree = case tree of
  Hole' ann x -> Hole' <$> onAnn ann <*> onHole x
  Prim' ann x -> (\ann' -> Prim' ann' x) <$> onAnn ann
  Roll' ann x -> Roll' <$> onAnn ann <*> repMapM (holesMapAnnM onHole onAnn) x

-- |Monadic map over the holes of a 'HolesAnn' that leaves every
-- annotation as it is: 'holesMapAnnM' with 'return' as the
-- annotation action.
holesMapM :: (Monad m)
          => (forall x . f x -> m (g x))
          -> HolesAnn kappa fam ann f a -> m (HolesAnn kappa fam ann g a)
holesMapM onHole tree = holesMapAnnM onHole return tree

-- |Pure map over the holes of a 'HolesAnn'; runs 'holesMapM' in
-- the 'Identity' monad.
holesMap :: (forall x . f x -> g x)
         -> HolesAnn kappa fam ann f a -> HolesAnn kappa fam ann g a
holesMap f tree = runIdentity (holesMapM (Identity . f) tree)

-- |Pure map over both the holes and the annotations of a
-- 'HolesAnn'; runs 'holesMapAnnM' in the 'Identity' monad.
holesMapAnn :: (forall x . f   x -> g x)
            -> (forall x . ann x -> phi x)
            -> HolesAnn kappa fam ann f a -> HolesAnn kappa fam phi g a
holesMapAnn f g tree =
  runIdentity (holesMapAnnM (Identity . f) (Identity . g) tree)

-- |Monadic multiplication: collapses a tree whose holes are
-- themselves trees. Each hole is replaced by the tree it contains
-- (the annotation stored on the hole itself is discarded); every
-- other node is rebuilt unchanged.
holesJoin :: HolesAnn kappa fam ann (HolesAnn kappa fam ann f) a
          -> HolesAnn kappa fam ann f a
holesJoin tree = case tree of
  Hole' _   inner -> inner
  Prim' ann x     -> Prim' ann x
  Roll' ann x     -> Roll' ann (repMap holesJoin x)

-- |Computes the list of holes in a 'HolesAnn'. A hole yields
-- itself, a primitive yields nothing and a 'Roll'' concatenates the
-- holes of the subtrees returned by 'repLeavesList', in that order.
holesHolesList :: HolesAnn kappa fam ann f a -> [Exists f]
holesHolesList tree = case tree of
  Hole' _ x -> [Exists x]
  Prim' _ _ -> []
  Roll' _ x -> [ hole | leaf <- repLeavesList x
                      , hole <- exElim holesHolesList leaf ]

{-
holesHolesSet :: (Ord (Exists f)) => Holes kappa fam f a -> S.Set (Exists f)
holesHolesSet = S.fromList . holesHolesList
-}

-- TODO: Implement holesMap in terms of refine; it's much better!

-- |Refines holes using a monadic action: every hole is mapped to
-- a whole tree with 'holesMapM' and the resulting tree of trees is
-- flattened with 'holesJoin'.
holesRefineHolesM :: (Monad m)
                  => (forall b . f b -> m (Holes kappa fam g b))
                  -> Holes kappa fam f a
                  -> m (Holes kappa fam g a)
holesRefineHolesM f tree = holesJoin <$> holesMapM f tree

-- |Refine holes with a simple (pure) action: every hole is mapped
-- to a whole tree and the result is flattened with 'holesJoin'.
holesRefineHoles :: (forall b . f b -> Holes kappa fam g b)
                 -> Holes kappa fam f a
                 -> Holes kappa fam g a
holesRefineHoles f = holesJoin . holesMap f

-- |Refine holes and primitives: holes are replaced using the first
-- action, primitives using the second, and every 'Roll' is rebuilt
-- around its recursively refined children.
holesRefineM :: (Monad m)
             => (forall b . f b -> m (Holes kappa fam g b))
             -> (forall b . (PrimCnstr kappa fam b)
                  => b -> m (Holes kappa fam g b))
             -> Holes kappa fam f a
             -> m (Holes kappa fam g a)
holesRefineM onHole onPrim tree = case tree of
  Hole x -> onHole x
  Prim x -> onPrim x
  Roll x -> Roll <$> repMapM (holesRefineM onHole onPrim) x

-- |Counts how many 'Prim''s and 'Roll''s are inside a 'HolesAnn';
-- holes contribute nothing to the size.
holesSize :: HolesAnn kappa fam ann h a -> Int
holesSize tree = case tree of
  Hole' _ _ -> 0
  Prim' _ _ -> 1
  Roll' _ x -> 1 + sum [ exElim holesSize leaf | leaf <- repLeavesList x ]

-- |Monadic catamorphism over 'HolesAnn'. The children of a 'Roll''
-- are folded first, through 'repMapM', and their results are then
-- handed to the recursion handler together with the annotation
-- stored at that node.
cataM :: (Monad m)
      => (forall b . (CompoundCnstr kappa fam b)
            => ann b -> SRep phi (Rep b) -> m (phi b)) -- ^ How to handle recursion
      -> (forall b . (PrimCnstr kappa fam b)
            => ann b -> b -> m (phi b)) -- ^ How to handle primitives
      -> (forall b . ann b -> h b -> m (phi b)) -- ^ How to handle holes
      -> HolesAnn kappa fam ann h a
      -> m (phi a)
cataM onRoll onPrim onHole tree = case tree of
  Roll' ann x -> onRoll ann =<< repMapM (cataM onRoll onPrim onHole) x
  Prim' ann x -> onPrim ann x
  Hole' ann x -> onHole ann x

-- |Synthesis of attributes: rebuilds the tree with every annotation
-- replaced by a freshly computed attribute. It is a 'cataM' whose
-- carrier is the re-annotated tree itself; at a 'Roll'' the handler
-- receives the attributes already computed for the children, which
-- are read off with 'getAnn'.
synthesizeM :: (Monad m)
            => (forall b . (CompoundCnstr kappa fam b)
                  => ann b -> SRep phi (Rep b) -> m (phi b)) -- ^ How to handle recursion
            -> (forall b . (PrimCnstr kappa fam b)
                  => ann b -> b -> m (phi b)) -- ^ How to handle primitives
            -> (forall b . ann b -> h b -> m (phi b)) -- ^ How to handle holes
            -> HolesAnn kappa fam ann h a
            -> m (HolesAnn kappa fam phi h a)
synthesizeM f g h = cataM
  (\ann kids -> (\attr -> Roll' attr kids) <$> f ann (repMap getAnn kids))
  (\ann prim -> (\attr -> Prim' attr prim) <$> g ann prim)
  (\ann hole -> (\attr -> Hole' attr hole) <$> h ann hole)

-- |Simpler version of 'synthesizeM' working over the /Identity/ monad:
-- the three pure handlers are lifted into 'Identity' and the result
-- is unwrapped with 'runIdentity'.
synthesize :: (forall b . (CompoundCnstr kappa fam b)
                 => ann b -> SRep phi (Rep b) -> phi b)
           -> (forall b . (PrimCnstr kappa fam b)
                 => ann b -> b -> phi b)
           -> (forall b . ann b -> h b -> phi b)
           -> HolesAnn kappa fam ann h a
           -> HolesAnn kappa fam phi h a
synthesize f g h tree = runIdentity $ synthesizeM
  (\ann kids -> Identity (f ann kids))
  (\ann prim -> Identity (g ann prim))
  (\ann hole -> Identity (h ann hole))
  tree

-- Anti unification is so simple it doesn't
-- deserve its own module

-- |Computes the /least general generalization/ of two
-- trees: the result keeps the structure on which both trees agree
-- and has a hole, holding the two disagreeing subtrees, everywhere
-- else.
--
-- Two primitives agree when 'weq' says so; two 'Roll's agree at the
-- top when 'zipSRep' succeeds, in which case their children are
-- generalized recursively. Any other combination becomes a hole.
lgg :: forall kappa fam h i a
     . (All Eq kappa)
    => Holes kappa fam h a -> Holes kappa fam i a
    -> Holes kappa fam (Holes kappa fam h :*: Holes kappa fam i) a
lgg l r = case (l , r) of
  (Prim x , Prim y) ->
    if weq (Proxy :: Proxy kappa) x y
      then Prim x
      else Hole (Prim x :*: Prim y)
  (Roll rl , Roll rr) ->
    maybe apart (Roll . repMap (uncurry' lgg)) (zipSRep rl rr)
  _ -> apart
  where
    -- The two trees disagree at the root.
    apart = Hole (l :*: r)

----------------------
-- Deep translation --
----------------------

{- It is possible to have a simpler GDeep, relying on
-- GShallow. I'll test performance later.

class GDeep' fam prim isPrim a where
  gdfrom'  :: Proxy isPrim -> a -> SFix fam prim a
  gdto'    :: Proxy isPrim -> SFix fam prim a -> a

instance (CompoundCnstr fam prim a , GDeep fam prim a)
     => GDeep' fam prim 'False a where
  gdfrom' _ a = gdfrom $ a
  gdto' _   x = gdto x

instance (PrimCnstr fam prim a) => GDeep' fam prim 'True a where
  gdfrom' _ a = Prim a
  gdto'   _ (Prim a) = a

class GDeep fam prim a where
  gdfrom :: a -> SFix fam prim a
  gdto   :: SFix fam prim a -> a

instance GDeep' fam prim (IsElem a prim) a => GDeep fam prim a where
  gdfrom = gdfrom' (Proxy :: Proxy (IsElem a prim))
  gdto   = gdto'   (Proxy :: Proxy (IsElem a prim))

dfrom :: forall fam prim a
       . (CompoundCnstr fam prim a)
      => a -> SFix fam prim a
dfrom = SFix
      . runIdentity
      . repMapCM (Proxy :: Proxy (GDeep fam prim))
         (\(I x) -> return $ gdfrom x)
      . fromS

-}

-- |Conversion between a compound type @a@ and its deep
-- representation 'SFix'. Both methods have defaults that go through
-- the "GHC.Generics" representation of @a@ and the 'GDeep' class.
class (CompoundCnstr kappa fam a) => Deep kappa fam a where
  -- |Converts a value into its deep representation. The default
  -- takes the generic representation with 'from', translates it with
  -- 'gdfrom' and wraps the result in 'SFix'.
  dfrom :: a -> SFix kappa fam a
  default dfrom :: (GDeep kappa fam (Rep a)) => a -> SFix kappa fam a
  dfrom x = SFix (gdfrom (from x))

  -- |Converts a deep representation back into a value. The default
  -- unwraps the 'SFix', translates with 'gdto' and rebuilds the value
  -- with 'to'.
  dto :: SFix kappa fam a -> a
  default dto :: (GDeep kappa fam (Rep a)) => SFix kappa fam a -> a
  dto (SFix r) = to (gdto r)

-- |Translation between a "GHC.Generics" representation functor @f@
-- and the corresponding 'SRep' whose leaves are 'SFix' values.
-- The instances in this module proceed by induction on the structure
-- of @f@.
class GDeep kappa fam f where
  -- |From a generic representation value to a 'SRep' over 'SFix'.
  gdfrom :: f x -> SRep (SFix kappa fam) f
  -- |From a 'SRep' over 'SFix' to a generic representation value.
  gdto   :: SRep (SFix kappa fam) f -> f x

-- |Translation of a single value of type @a@ to and from 'SFix'.
-- The type-level flag @isPrim@ selects the instance: there is one
-- instance for @'True@, requiring 'PrimCnstr', and one for @'False@,
-- requiring 'Deep'. The 'GDeep' instance for 'K1' computes the flag as
-- @IsElem a kappa@.
class GDeepAtom kappa fam (isPrim :: Bool) a where
  -- |From a value to its 'SFix'; the proxy only selects the instance.
  gdfromAtom  :: Proxy isPrim -> a -> SFix kappa fam a
  -- |From a 'SFix' to a value; the proxy only selects the instance.
  gdtoAtom    :: Proxy isPrim -> SFix kappa fam a -> a

-- |A non-primitive atom is translated through its 'Deep' instance.
instance (CompoundCnstr kappa fam a , Deep kappa fam a)
     => GDeepAtom kappa fam 'False a where
  gdfromAtom _ = dfrom
  gdtoAtom   _ = dto

-- |A primitive atom is stored as is under 'Prim' and read back
-- by matching on 'Prim'; no other pattern is handled.
instance (PrimCnstr kappa fam a) => GDeepAtom kappa fam 'True a where
  gdfromAtom _ = Prim
  gdtoAtom   _ (Prim prim) = prim

-- |A field of type @a@ is translated with 'GDeepAtom'; the instance
-- is selected by whether @a@ is an element of @kappa@
-- (@IsElem a kappa@).
instance (GDeepAtom kappa fam (IsElem a kappa) a) => GDeep kappa fam (K1 R a) where
  gdfrom (K1 val) = S_K1 (gdfromAtom isPrim val)
    where isPrim = Proxy :: Proxy (IsElem a kappa)
  gdto (S_K1 rep) = K1 (gdtoAtom isPrim rep)
    where isPrim = Proxy :: Proxy (IsElem a kappa)

-- |The unit representation 'U1' corresponds to 'S_U1'.
instance GDeep kappa fam U1 where
  gdfrom u = case u of
    U1 -> S_U1
  gdto s = case s of
    S_U1 -> U1

-- |Products are translated componentwise: ':*:' corresponds to ':**:'.
instance (GDeep kappa fam f , GDeep kappa fam g) => GDeep kappa fam (f :*: g) where
  gdfrom (l :*: r)  = gdfrom l :**: gdfrom r
  gdto   (l :**: r) = gdto l :*: gdto r

-- |Sums are translated injection by injection: 'L1' corresponds to
-- 'S_L1' and 'R1' to 'S_R1'.
instance (GDeep kappa fam f , GDeep kappa fam g) => GDeep kappa fam (f :+: g) where
  gdfrom s = case s of
    L1 l -> S_L1 (gdfrom l)
    R1 r -> S_R1 (gdfrom r)

  gdto s = case s of
    S_L1 l -> L1 (gdto l)
    S_R1 r -> R1 (gdto r)

-- |Metadata wrappers are translated around their contents: going
-- to 'SRep' the metadata is supplied by 'smeta', going back it is
-- ignored.
instance (GMeta i c , GDeep kappa fam f) => GDeep kappa fam (M1 i c f) where
  gdfrom m = case m of
    M1 inner -> S_M1 smeta (gdfrom inner)
  gdto s = case s of
    S_M1 _ inner -> M1 (gdto inner)