{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators #-}
module What4.Utils.Serialize
    (
      withRounding
    , makeSymbol
    , asyncLinked
    , withAsyncLinked
    ) where

import qualified Control.Exception as E
import           Text.Printf ( printf )
import qualified Data.BitVector.Sized as BV
import           What4.BaseTypes
import qualified What4.Interface as S
import           What4.Symbol ( SolverSymbol, userSymbol )


import qualified UnliftIO as U

----------------------------------------------------------------
-- * Async

-- | Fork an async action that is linked to the parent thread, but can
-- be safely 'U.cancel'd without also killing the parent thread.
--
-- Note that if your async doesn't return unit, then you probably want
-- to 'U.wait' for it instead, which eliminates the need for linking
-- it. Also, if you plan to cancel the async near where you fork it,
-- then 'withAsyncLinked' is a better choice than using this function
-- and subsequently canceling, since it ensures cancellation.
--
-- See https://github.com/simonmar/async/issues/25 for a perhaps more
-- robust, but also harder to use version of this. The linked version
-- is harder to use because it requires a special version of @cancel@.
asyncLinked :: (U.MonadUnliftIO m) => m () -> m (U.Async ())
asyncLinked action =
  -- We use 'U.mask' to avoid a race condition between starting the
  -- async and running @action@. Without 'U.mask' here, an async
  -- exception (e.g. via 'U.cancel') could arrive after
  -- @handleUnliftIO@ starts to run but before @action@ starts.
  U.mask $ \restore -> do
    -- The child runs @action@ under the caller's original masking state,
    -- wrapped so that a 'E.ThreadKilled' raised inside it is absorbed by
    -- 'threadKilledHandler' instead of escaping the async.
    a <- U.async $ handleUnliftIO threadKilledHandler (restore action)
    -- Linking and returning the handle also happen under the caller's
    -- original masking state.
    restore $ do
      U.link a
      return a

-- | Handle asynchronous 'E.ThreadKilled' exceptions without killing the parent
-- thread. All other forms of asynchronous exceptions are rethrown.
threadKilledHandler :: Monad m => E.AsyncException -> m ()
-- 'E.ThreadKilled' is swallowed: the handler simply returns unit.
threadKilledHandler E.ThreadKilled = return ()
-- Every other 'E.AsyncException' constructor is rethrown unchanged.
threadKilledHandler e              = E.throw e

-- | A version of 'U.withAsync' that safely links the child. See
-- 'asyncLinked'.
withAsyncLinked :: (U.MonadUnliftIO m) => m () -> (U.Async () -> m a) -> m a
withAsyncLinked child parent =
  -- As in 'asyncLinked', the async is started under 'U.mask' and the
  -- caller's masking state is restored separately for the child action
  -- and for the parent continuation.
  U.mask $ \restore ->
    -- The child is wrapped so that a 'E.ThreadKilled' raised inside it is
    -- absorbed by 'threadKilledHandler'.
    U.withAsync (handleUnliftIO threadKilledHandler $ restore child) $ \a ->
      restore $ do
        U.link a
        parent a

-- A 'U.MonadUnliftIO' version of 'Control.Exception.handle'.
--
-- The 'U.handle' doesn't catch async exceptions, because the
-- @unliftio@ library uses the @safe-exceptions@ library, not
-- @base@, for its exception handling primitives. This is very
-- confusing if you're not expecting it!
handleUnliftIO :: (U.MonadUnliftIO m, U.Exception e)
               => (e -> m a) -> m a -> m a
handleUnliftIO h a = U.withUnliftIO $ \u ->
  -- Both the handler @h@ and the protected action @a@ are run down to 'IO'
  -- with the captured unlifting function, so that 'E.handle' from
  -- "Control.Exception" can be applied to them directly.
  E.handle (U.unliftIO u . h) (U.unliftIO u a)

-- | Try converting any 'String' into a 'SolverSymbol'. If it is an invalid
-- symbol, then error.
makeSymbol :: String -> SolverSymbol
makeSymbol name =
  case userSymbol sanitizedName of
    Right symbol -> symbol
    -- The error message reports both the name as given and its sanitized form.
    Left _ ->
      error $ printf "tried to create symbol with bad name: %s (%s)"
                     name sanitizedName
  where
    -- We use a custom name sanitizer here because downstream clients may depend
    -- on the format of the name. It would be nice to use 'safeSymbol' here, but
    -- it mangles names with z-encoding in a way that might be unusable
    -- downstream.
    sanitizedName :: String
    sanitizedName = map sanitizeChar name

    -- Spaces and dots become underscores; every other character is kept.
    sanitizeChar :: Char -> Char
    sanitizeChar ' ' = '_'
    sanitizeChar '.' = '_'
    sanitizeChar c   = c

-- | Run @action@ under the rounding mode selected by the 2-bit bitvector @r@.
--
-- @r@ is compared against the encodings given by 'roundingModeToBits' for
-- 'S.RNE', 'S.RTZ' and 'S.RTP'. The results of @action@ are combined with
-- nested 'S.iteM' / 'S.baseTypeIte' selections on those comparisons, and
-- 'S.RTN' is the final fall-through branch.
withRounding
  :: forall sym tp
   . S.IsExprBuilder sym
  => sym
  -> S.SymBV sym 2
  -> (S.RoundingMode -> IO (S.SymExpr sym tp))
  -> IO (S.SymExpr sym tp)
withRounding sym r action = do
  cRNE <- roundingCond S.RNE
  cRTZ <- roundingCond S.RTZ
  cRTP <- roundingCond S.RTP
  S.iteM S.baseTypeIte sym cRNE
    (action S.RNE) $
    S.iteM S.baseTypeIte sym cRTZ
      (action S.RTZ) $
      S.iteM S.baseTypeIte sym cRTP (action S.RTP) (action S.RTN)
 where
  -- Predicate stating that @r@ equals the 2-bit literal encoding of @rm@.
  roundingCond :: S.RoundingMode -> IO (S.Pred sym)
  roundingCond rm =
    S.bvEq sym r =<< S.bvLit sym knownNat (BV.mkBV knownNat (roundingModeToBits rm))

-- | Integer encoding of a rounding mode: 'S.RNE' is 0, 'S.RTZ' is 1,
-- 'S.RTP' is 2 and 'S.RTN' is 3.
--
-- 'S.RNA' has no encoding here and calls 'error'.
roundingModeToBits :: S.RoundingMode -> Integer
roundingModeToBits = \case
  S.RNE -> 0
  S.RTZ -> 1
  S.RTP -> 2
  S.RTN -> 3
  S.RNA -> error $ "unsupported rounding mode: " ++ show S.RNA