{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Core.Data.Class.ExtractSymbolics
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.ExtractSymbolics
  ( -- * Extracting symbolic constant set from a value
    ExtractSymbolics (..),
  )
where

import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Trans.Maybe
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum
import Data.Int
import Data.Word
import Generics.Deriving
import {-# SOURCE #-} Grisette.IR.SymPrim.Data.Prim.Model

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> import Grisette.Lib.Base
-- >>> import Data.HashSet as HashSet
-- >>> import Data.List (sort)

-- | Extracts all the symbolic variables that are transitively contained in
-- the given value.
--
-- >>> extractSymbolics ("a" :: SymBool) :: SymbolSet
-- SymbolSet {a :: Bool}
--
-- >>> extractSymbolics (mrgIf "a" (mrgReturn ["b"]) (mrgReturn ["c", "d"]) :: UnionM [SymBool]) :: SymbolSet
-- SymbolSet {a :: Bool, b :: Bool, c :: Bool, d :: Bool}
--
-- __Note:__ This type class can be derived for algebraic data types through
-- their 'Generic' representation. You may need the @DerivingVia@ and
-- @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving ExtractSymbolics via (Default X)
class ExtractSymbolics a where
  -- | Collects every symbolic variable reachable from the value into a
  -- 'SymbolSet'.
  extractSymbolics :: a -> SymbolSet

-- | Generic implementation used by
-- @deriving ExtractSymbolics via (Default X)@: unwrap the value, convert it
-- to its 'Generic' representation with 'from', and walk that representation
-- with the internal 'ExtractSymbolics'' class.
instance
  (Generic a, ExtractSymbolics' (Rep a)) =>
  ExtractSymbolics (Default a)
  where
  extractSymbolics = extractSymbolics' . from . unDefault

-- | Internal helper class (not exported from this module) that walks a
-- 'Generic' representation type @f@ and gathers the symbolic variables held
-- in its fields.
class ExtractSymbolics' f where
  extractSymbolics' :: f p -> SymbolSet

-- | A constructor without fields holds no symbolic variables.
instance ExtractSymbolics' U1 where
  extractSymbolics' _ = mempty

-- | A single field: unwrap it and defer to the field type's own
-- 'ExtractSymbolics' instance.
instance (ExtractSymbolics c) => ExtractSymbolics' (K1 i c) where
  extractSymbolics' = extractSymbolics . unK1

-- | Metadata wrappers carry no data of their own; look straight through them
-- to the wrapped representation.
instance (ExtractSymbolics' a) => ExtractSymbolics' (M1 i c a) where
  extractSymbolics' = extractSymbolics' . unM1

-- | Sums (choice between constructors): only the branch that is actually
-- present is inspected.
instance
  (ExtractSymbolics' a, ExtractSymbolics' b) =>
  ExtractSymbolics' (a :+: b)
  where
  extractSymbolics' (L1 l) = extractSymbolics' l
  extractSymbolics' (R1 r) = extractSymbolics' r

-- | Products (several fields of one constructor): inspect both halves and
-- combine their results with '(<>)'.
instance
  (ExtractSymbolics' a, ExtractSymbolics' b) =>
  ExtractSymbolics' (a :*: b)
  where
  extractSymbolics' (l :*: r) = extractSymbolics' l <> extractSymbolics' r

-- instances

-- Concrete (non-symbolic) types hold no symbolic variables, so extraction
-- always yields the empty set. The macro below stamps out one such instance
-- for each type it is given.
#define CONCRETE_EXTRACT_SYMBOLICS(type) \
instance  ExtractSymbolics type where \
  extractSymbolics _ = mempty

-- NOTE(review): the always-true conditional below does not change the
-- compiled code; it presumably exists to keep source formatters away from
-- the macro invocations. Confirm before removing.
#if 1
CONCRETE_EXTRACT_SYMBOLICS(Bool)
CONCRETE_EXTRACT_SYMBOLICS(Integer)
CONCRETE_EXTRACT_SYMBOLICS(Char)
CONCRETE_EXTRACT_SYMBOLICS(Int)
CONCRETE_EXTRACT_SYMBOLICS(Int8)
CONCRETE_EXTRACT_SYMBOLICS(Int16)
CONCRETE_EXTRACT_SYMBOLICS(Int32)
CONCRETE_EXTRACT_SYMBOLICS(Int64)
CONCRETE_EXTRACT_SYMBOLICS(Word)
CONCRETE_EXTRACT_SYMBOLICS(Word8)
CONCRETE_EXTRACT_SYMBOLICS(Word16)
CONCRETE_EXTRACT_SYMBOLICS(Word32)
CONCRETE_EXTRACT_SYMBOLICS(Word64)
CONCRETE_EXTRACT_SYMBOLICS(B.ByteString)
#endif

-- (): the unit value holds no symbolic variables.
instance ExtractSymbolics () where
  extractSymbolics _ = mempty

-- Either: borrowed from the Generic-based Default instance, so only the
-- side that is actually present is inspected.
deriving via (Default (Either a b)) instance
  (ExtractSymbolics a, ExtractSymbolics b) => ExtractSymbolics (Either a b)

-- Maybe: borrowed from the Generic-based Default instance; the fieldless
-- Nothing constructor contributes an empty result.
deriving via (Default (Maybe a)) instance
  ExtractSymbolics a => ExtractSymbolics (Maybe a)

-- Lists: borrowed from the Generic-based Default instance; the generic walk
-- recurses through every cons cell and combines the per-element results
-- with (<>).
deriving via (Default [a]) instance
  ExtractSymbolics a => ExtractSymbolics [a]

-- Pairs: borrowed from the Generic-based Default instance; the results of
-- both components are combined with (<>).
deriving via (Default (a, b)) instance
  (ExtractSymbolics a, ExtractSymbolics b) => ExtractSymbolics (a, b)

-- Triples: borrowed from the Generic-based Default instance; the results of
-- all three components are combined with (<>).
deriving via (Default (a, b, c)) instance
  (ExtractSymbolics a, ExtractSymbolics b, ExtractSymbolics c) =>
  ExtractSymbolics (a, b, c)

-- MaybeT: unwrap the transformer and inspect the underlying m (Maybe a).
instance (ExtractSymbolics (m (Maybe a))) => ExtractSymbolics (MaybeT m a) where
  extractSymbolics (MaybeT v) = extractSymbolics v

-- ExceptT: unwrap the transformer and inspect the underlying
-- m (Either e a), which covers both the error and the success values.
instance
  (ExtractSymbolics (m (Either e a))) =>
  ExtractSymbolics (ExceptT e m a)
  where
  extractSymbolics (ExceptT v) = extractSymbolics v

-- Functor Sum: borrowed from the Generic-based Default instance, so only the
-- InL or InR payload that is present is inspected.
deriving via (Default (Sum f g a)) instance
  (ExtractSymbolics (f a), ExtractSymbolics (g a)) =>
  ExtractSymbolics (Sum f g a)

-- WriterT (lazy): unwrap the transformer and inspect the underlying
-- m (a, s), which holds both the result and the accumulated output.
instance
  (ExtractSymbolics (m (a, s))) =>
  ExtractSymbolics (WriterLazy.WriterT s m a)
  where
  extractSymbolics (WriterLazy.WriterT f) = extractSymbolics f

-- WriterT (strict): same as the lazy variant; inspect the underlying
-- m (a, s), which holds both the result and the accumulated output.
instance
  (ExtractSymbolics (m (a, s))) =>
  ExtractSymbolics (WriterStrict.WriterT s m a)
  where
  extractSymbolics (WriterStrict.WriterT f) = extractSymbolics f

-- Identity: a transparent wrapper; inspect the wrapped value.
instance (ExtractSymbolics a) => ExtractSymbolics (Identity a) where
  extractSymbolics (Identity a) = extractSymbolics a

-- IdentityT: a transparent transformer; inspect the wrapped m a.
instance (ExtractSymbolics (m a)) => ExtractSymbolics (IdentityT m a) where
  extractSymbolics (IdentityT a) = extractSymbolics a