{-# LANGUAGE FlexibleContexts #-}

-- |
-- Module      :   Grisette.Lib.Data.List
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Lib.Data.List
  ( -- * Symbolic versions of 'Data.List' operations
    (.!!),
    symFilter,
    symTake,
    symDrop,
  )
where

import Control.Exception (ArrayException (IndexOutOfBounds))
import Control.Monad.Except (MonadError (throwError))
import Grisette.Core.Control.Monad.Union (MonadUnion)
import Grisette.Core.Data.Class.Error (TransformError (transformError))
import Grisette.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Core.Data.Class.SEq (SEq ((.==)))
import Grisette.Core.Data.Class.SOrd (SOrd ((.<=)))
import Grisette.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.IR.SymPrim.Data.SymPrim (SymBool, SymInteger)
import Grisette.Lib.Control.Monad (mrgFmap, mrgReturn)

-- | Symbolic counterpart of 'Data.List.!!'.
--
-- The (possibly symbolic) index is compared against each concrete position
-- @0, 1, 2, ...@ of the list in turn. The matching element is selected through
-- a chain of 'mrgIf's, so the result is merged and the mergeable knowledge is
-- propagated. If the index matches no position (including any negative index),
-- the error @'IndexOutOfBounds' \"!!~\"@, lifted with 'transformError', is
-- thrown.
(.!!) ::
  ( MonadUnion uf,
    MonadError e uf,
    TransformError ArrayException e,
    Mergeable a
  ) =>
  [a] ->
  SymInteger ->
  uf a
xs0 .!! idx = walk xs0 0
  where
    -- Ran past the end without the index matching any position.
    walk [] _ = throwError (transformError (IndexOutOfBounds "!!~"))
    -- Either the requested index is this position, or keep looking further.
    walk (y : ys) pos =
      mrgIf (idx .== pos) (mrgReturn y) (walk ys (pos + 1))

-- | Symbolic counterpart of 'Data.List.filter'.
--
-- Each element is kept or dropped under the symbolic condition returned by
-- the predicate. Every intermediate result is merged with 'mrgIf' /
-- 'mrgReturn', so the mergeable knowledge is propagated.
symFilter :: (MonadUnion u, Mergeable a) => (a -> SymBool) -> [a] -> u [a]
symFilter keep = foldr step (mrgReturn [])
  where
    -- Filter the tail first, then decide symbolically whether the current
    -- element is prepended to the filtered tail.
    step y rest =
      rest >>= \kept ->
        mrgIf (keep y) (mrgReturn (y : kept)) (mrgReturn kept)

-- | Symbolic counterpart of 'Data.List.take'.
--
-- Returns the prefix of the list whose length is the (possibly symbolic)
-- count. Once the count is at most zero, the empty list is produced. Results
-- are merged via 'mrgIf' and 'mrgFmap', so the mergeable knowledge is
-- propagated.
symTake :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symTake n xs = case xs of
  [] -> mrgReturn []
  y : ys ->
    mrgIf
      (n .<= 0)
      (mrgReturn [])
      -- Otherwise keep the head and take one fewer from the tail.
      (mrgFmap (y :) (symTake (n - 1) ys))

-- | Symbolic counterpart of 'Data.List.drop'.
--
-- Discards a (possibly symbolic) number of leading elements. Once the count
-- is at most zero, the remaining list is returned as-is. Results are merged
-- via 'mrgIf', so the mergeable knowledge is propagated.
symDrop :: (MonadUnion u, Mergeable a) => SymInteger -> [a] -> u [a]
symDrop n xs = case xs of
  [] -> mrgReturn []
  -- Either stop here and return the whole remaining list, or skip the head.
  _ : ys -> mrgIf (n .<= 0) (mrgReturn xs) (symDrop (n - 1) ys)