{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.AllSyms
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.AllSyms
  ( symSize,
    symsSize,
    SomeSym (..),
    AllSyms (..),
    allSymsSize,
  )
where

import Control.Monad.Except (ExceptT (ExceptT))
import Control.Monad.Identity
  ( Identity (Identity),
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe (MaybeT (MaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum (Sum)
import Data.Int (Int16, Int32, Int64, Int8)
import qualified Data.Text as T
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.Generics
  ( Generic (Rep, from),
    K1 (K1),
    M1 (M1),
    U1,
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Generics.Deriving (Default (Default, unDefault))
import Grisette.Internal.Core.Control.Exception
  ( AssertionError,
    VerificationConditions,
  )
import Grisette.Internal.SymPrim.Prim.SomeTerm
  ( SomeTerm (SomeTerm),
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( LinkedRep (underlyingTerm),
  )
import Grisette.Internal.SymPrim.Prim.TermUtils
  ( someTermsSize,
    termSize,
    termsSize,
  )

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | Get the sum of the sizes of a list of symbolic terms.
-- Duplicate sub-terms are counted only once.
--
-- >>> symsSize [1, "a" :: SymInteger, "a" + 1 :: SymInteger]
-- 3
symsSize :: forall con sym. (LinkedRep con sym) => [sym] -> Int
-- The @con type application pins the concrete type of the underlying terms to
-- the scoped type variable from the signature.
symsSize = termsSize . fmap (underlyingTerm @con)
{-# INLINE symsSize #-}

-- | Get the size of a symbolic term.
-- Duplicate sub-terms are counted only once.
--
-- >>> symSize (1 :: SymInteger)
-- 1
-- >>> symSize ("a" :: SymInteger)
-- 1
-- >>> symSize ("a" + 1 :: SymInteger)
-- 3
-- >>> symSize (("a" + 1) * ("a" + 1) :: SymInteger)
-- 4
symSize :: forall con sym. (LinkedRep con sym) => sym -> Int
-- The @con type application pins the concrete type of the underlying term to
-- the scoped type variable from the signature.
symSize = termSize . underlyingTerm @con
{-# INLINE symSize #-}

-- | An existential wrapper around a symbolic value.
--
-- The concrete type @con@ and the symbolic type @sym@ are hidden; only the
-- 'LinkedRep' evidence relating them is retained, so values of different
-- symbolic types can be stored in one list.
data SomeSym where
  SomeSym :: forall con sym. (LinkedRep con sym) => sym -> SomeSym

-- | Unwrap a 'SomeSym' and repackage its underlying term as a 'SomeTerm'.
someUnderlyingTerm :: SomeSym -> SomeTerm
someUnderlyingTerm (SomeSym s) = SomeTerm $ underlyingTerm s

-- | Size of the terms underlying a list of 'SomeSym', as computed by
-- 'someTermsSize' on their underlying terms.
someSymsSize :: [SomeSym] -> Int
someSymsSize = someTermsSize . fmap someUnderlyingTerm
{-# INLINE someSymsSize #-}

-- | Extract all symbolic primitive values that are represented as SMT terms.
--
-- __Note:__ This type class can be derived for algebraic data types. You may
-- need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving AllSyms via (Default X)
--
-- The two methods are defined in terms of each other, so an instance must
-- implement at least one of them (see the @MINIMAL@ pragma).
class AllSyms a where
  -- | Convert a value to a list of symbolic primitive values. It should
  -- prepend to an existing list of symbolic primitive values.
  --
  -- The default implementation prepends the result of 'allSyms' with '++'.
  allSymsS :: a -> [SomeSym] -> [SomeSym]
  allSymsS a l = allSyms a ++ l

  -- | Specialized 'allSymsS' that prepends to an empty list.
  --
  -- The default implementation calls 'allSymsS' with @[]@.
  allSyms :: a -> [SomeSym]
  allSyms a = allSymsS a []

  {-# MINIMAL allSymsS | allSyms #-}

-- | Get the total size of symbolic terms in a value.
-- Duplicate sub-terms are counted only once.
--
-- >>> allSymsSize ("a" :: SymInteger, "a" + "b" :: SymInteger, ("a" + "b") * "c" :: SymInteger)
-- 5
allSymsSize :: (AllSyms a) => a -> Int
allSymsSize = someSymsSize . allSyms

-- | Counterpart of 'AllSyms' for "GHC.Generics" representation types, which
-- carry one extra (unused) type parameter.
class AllSyms' f where
  -- | Prepend the symbolic values found in a generic representation to an
  -- existing list.
  allSymsS' :: f p -> [SomeSym] -> [SomeSym]

-- | Generic implementation used with @deriving via ('Default' a)@: unwrap the
-- 'Default' newtype, convert to the generic representation with 'from', and
-- traverse that representation with 'allSymsS''.
instance (Generic a, AllSyms' (Rep a)) => AllSyms (Default a) where
  allSymsS = allSymsS' . from . unDefault

-- | A constructor without fields contributes no symbolic values: the
-- accumulated list is returned unchanged.
instance AllSyms' U1 where
  allSymsS' _ = id

-- | A single field: delegate to the 'AllSyms' instance of the field type.
instance (AllSyms c) => AllSyms' (K1 i c) where
  allSymsS' (K1 v) = allSymsS v

-- | Metadata wrapper: ignore the metadata and traverse the wrapped
-- representation.
instance (AllSyms' a) => AllSyms' (M1 i c a) where
  allSymsS' (M1 v) = allSymsS' v

-- | Sum of constructors: traverse whichever alternative is present.
instance (AllSyms' a, AllSyms' b) => AllSyms' (a :+: b) where
  allSymsS' (L1 l) = allSymsS' l
  allSymsS' (R1 r) = allSymsS' r

-- | Product of fields: compose the two prepending functions, so the symbols of
-- the left component end up before those of the right component in the
-- resulting list.
instance (AllSyms' a, AllSyms' b) => AllSyms' (a :*: b) where
  allSymsS' (a :*: b) = allSymsS' a . allSymsS' b

-- CONCRETE_ALLSYMS(type) expands to an AllSyms instance for the given type
-- whose allSymsS ignores the value and returns the accumulated list unchanged,
-- that is, the type contributes no symbolic values.
--
-- NOTE(review): the always-true "if 1" guard around the macro uses below looks
-- like a way to keep source formatters away from the macro invocations; confirm
-- before removing it.
#define CONCRETE_ALLSYMS(type) \
instance AllSyms type where \
  allSymsS _ = id

#if 1
CONCRETE_ALLSYMS(Bool)
CONCRETE_ALLSYMS(Integer)
CONCRETE_ALLSYMS(Char)
CONCRETE_ALLSYMS(Int)
CONCRETE_ALLSYMS(Int8)
CONCRETE_ALLSYMS(Int16)
CONCRETE_ALLSYMS(Int32)
CONCRETE_ALLSYMS(Int64)
CONCRETE_ALLSYMS(Word)
CONCRETE_ALLSYMS(Word8)
CONCRETE_ALLSYMS(Word16)
CONCRETE_ALLSYMS(Word32)
CONCRETE_ALLSYMS(Word64)
CONCRETE_ALLSYMS(B.ByteString)
CONCRETE_ALLSYMS(T.Text)
#endif

-- | The unit value contains no symbolic values; the argument is not inspected
-- and the accumulated list is returned unchanged.
instance AllSyms () where
  allSymsS _ = id

-- The instances below are all obtained through the Generic-based
-- @AllSyms (Default a)@ instance: every field of the value is traversed with
-- its own 'AllSyms' instance, which is why each type parameter needs an
-- 'AllSyms' constraint.

-- Either: symbols of whichever alternative is present.
deriving via
  (Default (Either a b))
  instance
    ( AllSyms a,
      AllSyms b
    ) =>
    AllSyms (Either a b)

-- Maybe
deriving via (Default (Maybe a)) instance (AllSyms a) => AllSyms (Maybe a)

-- List
deriving via (Default [a]) instance (AllSyms a) => AllSyms [a]

-- (,): pair
deriving via
  (Default (a, b))
  instance
    (AllSyms a, AllSyms b) =>
    AllSyms (a, b)

-- (,,): 3-tuple
deriving via
  (Default (a, b, c))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c
    ) =>
    AllSyms (a, b, c)

-- (,,,): 4-tuple
deriving via
  (Default (a, b, c, d))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c,
      AllSyms d
    ) =>
    AllSyms (a, b, c, d)

-- (,,,,): 5-tuple
deriving via
  (Default (a, b, c, d, e))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c,
      AllSyms d,
      AllSyms e
    ) =>
    AllSyms (a, b, c, d, e)

-- (,,,,,): 6-tuple
deriving via
  (Default (a, b, c, d, e, f))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c,
      AllSyms d,
      AllSyms e,
      AllSyms f
    ) =>
    AllSyms (a, b, c, d, e, f)

-- (,,,,,,): 7-tuple
deriving via
  (Default (a, b, c, d, e, f, g))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c,
      AllSyms d,
      AllSyms e,
      AllSyms f,
      AllSyms g
    ) =>
    AllSyms (a, b, c, d, e, f, g)

-- (,,,,,,,): 8-tuple (the instance head is written in prefix form)
deriving via
  (Default (a, b, c, d, e, f, g, h))
  instance
    ( AllSyms a,
      AllSyms b,
      AllSyms c,
      AllSyms d,
      AllSyms e,
      AllSyms f,
      AllSyms g,
      AllSyms h
    ) =>
    AllSyms ((,,,,,,,) a b c d e f g h)

-- MaybeT: unwrap the newtype and delegate to the instance for the wrapped
-- @m (Maybe a)@ value.
instance
  (AllSyms (m (Maybe a))) =>
  AllSyms (MaybeT m a)
  where
  allSymsS (MaybeT v) = allSymsS v

-- ExceptT: unwrap the newtype and delegate to the instance for the wrapped
-- @m (Either e a)@ value.
instance
  (AllSyms (m (Either e a))) =>
  AllSyms (ExceptT e m a)
  where
  allSymsS (ExceptT v) = allSymsS v

-- Sum (from Data.Functor.Sum): derived through the Generic-based
-- @AllSyms (Default a)@ instance; requires 'AllSyms' for both @f a@ and @g a@.
deriving via
  (Default (Sum f g a))
  instance
    (AllSyms (f a), AllSyms (g a)) =>
    AllSyms (Sum f g a)

-- WriterT (lazy): unwrap the newtype and delegate to the instance for the
-- wrapped @m (a, s)@ value.
instance
  (AllSyms (m (a, s))) =>
  AllSyms (WriterLazy.WriterT s m a)
  where
  allSymsS (WriterLazy.WriterT v) = allSymsS v

-- WriterT (strict): unwrap the newtype and delegate to the instance for the
-- wrapped @m (a, s)@ value.
instance
  (AllSyms (m (a, s))) =>
  AllSyms (WriterStrict.WriterT s m a)
  where
  allSymsS (WriterStrict.WriterT v) = allSymsS v

-- Identity: unwrap the newtype and delegate to the instance for the wrapped
-- value.
instance (AllSyms a) => AllSyms (Identity a) where
  allSymsS (Identity a) = allSymsS a

-- IdentityT: unwrap the newtype and delegate to the instance for the wrapped
-- @m a@ value.
instance (AllSyms (m a)) => AllSyms (IdentityT m a) where
  allSymsS (IdentityT a) = allSymsS a

-- VerificationConditions: derived through the Generic-based
-- @AllSyms (Default a)@ instance.
deriving via (Default VerificationConditions) instance AllSyms VerificationConditions

-- AssertionError: derived through the Generic-based @AllSyms (Default a)@
-- instance.
deriving via (Default AssertionError) instance AllSyms AssertionError