{- |
Module      : Data.LLVM.BitCode.Assert
Description : This module implements exceptions and warnings about bitcode.
License     : BSD3
Maintainer  : lbarrett
Stability   : experimental

This module is meant to be imported qualified as @Assert@

-}

{-# LANGUAGE CPP #-}
module Data.LLVM.BitCode.Assert
  ( failWithMsg
  , unknownEntity
  -- ** Record size
  , recordSizeLess
  , recordSizeGreater
  , recordSizeBetween
  , recordSizeIn

  -- ** Types
  , elimPtrTo
  , elimPtrTo_
  ) where

import           Control.Monad (MonadPlus, mplus)
import           Control.Monad (when)
#if !MIN_VERSION_base(4,13,0)
import           Control.Monad.Fail (MonadFail)
#endif
import           Data.LLVM.BitCode.Record (Record)
import qualified Data.LLVM.BitCode.Record as Record
import           Text.LLVM.AST (Type', Ident)
import qualified Text.LLVM.AST as AST

-- | Hint lines appended to every failure raised through 'failWithMsg',
-- pointing the user at the list of supported compilers.
supportedCompilerMessage :: [String]
supportedCompilerMessage =
  [ "Are you sure you're using a supported compiler?"
  , "Check here: https://github.com/GaloisInc/llvm-pretty-bc-parser"
  ]

-- | Call 'fail' with a helpful hint to the user.
--
-- The given message becomes the first line of the failure text; the lines of
-- 'supportedCompilerMessage' follow it, one per line (joined with 'unlines',
-- so the text ends in a newline).
failWithMsg :: MonadFail m => String -> m a
failWithMsg s = fail (unlines (s : supportedCompilerMessage))

-- | For when an unknown value of an enumeration is encountered.
--
-- Fails (via 'failWithMsg') with the message @Unknown \<sort\> \<value\>@,
-- where @sort@ names the kind of entity being decoded and the offending
-- value is rendered with 'show'.
unknownEntity :: (MonadFail m, Show a) => String -> a -> m b
unknownEntity sort val = failWithMsg (unwords ["Unknown", sort, show val])

----------------------------------------------------------------
-- ** Record sizes

-- | Shared implementation of the record-size assertions.
--
-- The number of fields in the record is passed to the predicate. Note the
-- polarity: the predicate describes the /invalid/ sizes, i.e. the check
-- fails (via 'failWithMsg') exactly when the predicate returns 'True'. The
-- failure text reports the actual size followed by the supplied message.
recordSizeCmp :: MonadFail m => String -> (Int -> Bool) -> Record -> m ()
recordSizeCmp msg compare_ record =
  when (compare_ len) $
    failWithMsg $ unlines
      [ "Invalid record size: " ++ show len
      , msg
      ]
  where
    len = length (Record.recordFields record)

-- | Assert that the record has strictly fewer than @i@ fields.
--
-- Fails when @i <= size@; the failure message names the bound that was
-- violated.
recordSizeLess :: MonadFail m => Record -> Int -> m ()
recordSizeLess r i =
  recordSizeCmp ("Expected size less than " ++ show i) (i <=) r

-- | Assert that the record has strictly more than @i@ fields.
--
-- Fails when @size <= i@; the failure message names the bound that was
-- violated.
recordSizeGreater :: MonadFail m => Record -> Int -> m ()
recordSizeGreater r i =
  recordSizeCmp ("Expected size greater than " ++ show i) (<= i) r

-- | Assert that the record's field count lies strictly between the lower
-- bound @lb@ and the upper bound @ub@ (both bounds are exclusive).
--
-- The lower bound is checked first, so a record violating both bounds
-- reports the lower-bound failure.
recordSizeBetween :: MonadFail m => Record -> Int -> Int -> m ()
recordSizeBetween record lb ub =
  recordSizeGreater record lb >> recordSizeLess record ub

-- | Assert that the record's field count is one of the sizes listed in @ns@.
--
-- On failure (via 'failWithMsg') the message reports both the actual size
-- and the list of accepted sizes.
recordSizeIn :: MonadFail m => Record -> [Int] -> m ()
recordSizeIn record ns =
  when (len `notElem` ns) $
    failWithMsg $ unlines
      [ "Invalid record size: " ++ show len
      , "Expected one of: " ++ show ns
      ]
  where
    len = length (Record.recordFields record)


----------------------------------------------------------------
-- ** Types

-- | Assert that this thing is a @'PtrTo' ty@ and return the underlying @ty@.
--
-- The work is delegated to 'AST.elimPtrTo'; if that alternative does not
-- succeed, 'mplus' falls through to a 'fail' whose text is the caller's
-- message, a fixed explanation, and the 'show'n type.
--
-- Think carefully before using this function, as it will not work as you would
-- expect when the type is an opaque pointer.
-- See @Note [Pointers and pointee types]@.
elimPtrTo :: (MonadFail m, MonadPlus m) => String -> Type' Ident -> m (Type' Ident)
elimPtrTo msg ptrTy = AST.elimPtrTo ptrTy `mplus` notAPointer
  where
    -- Fallback taken when the type could not be eliminated as a pointer.
    notAPointer =
      fail $ unlines
        [ msg
        , "Expected pointer type, found:"
        , show ptrTy
        ]

-- | Assert that this thing is a 'PtrTo' type.
--
-- Same check as 'elimPtrTo', but the pointee type is discarded; only the
-- success or failure of the assertion is observable.
--
-- Think carefully before using this function, as it will not work as you would
-- expect when the type is an opaque pointer.
-- See @Note [Pointers and pointee types]@.
elimPtrTo_ :: (MonadFail m, MonadPlus m) => String -> Type' Ident -> m ()
elimPtrTo_ msg ptrTy = () <$ elimPtrTo msg ptrTy

{-
Note [Pointers and pointee types]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Unlike LLVM itself, llvm-pretty and llvm-pretty-bc-parser allow mixing opaque
and non-opaque pointers. A consequence of this is that we generally avoid
pattern matching on PtrTo (non-opaque pointer) types and inspecting the
underlying pointee types. This sort of code simply won't work for PtrOpaque
types, which lack pointee types.

The elimPtrTo and elimPtrTo_ functions go against this rule, as they retrieve
the pointee type in a PtrTo. These functions are primarily used for supporting
old versions of LLVM which do not store the necessary type information in the
instruction itself.
-}