{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
module Data.LLVM.BitCode.IR.Constants where
import Data.LLVM.BitCode.Bitstream
import Data.LLVM.BitCode.IR.Values
import Data.LLVM.BitCode.Match
import Data.LLVM.BitCode.Parse
import Data.LLVM.BitCode.Record
import Text.LLVM.AST
import Control.Monad (mplus,mzero,foldM,(<=<),when)
import Control.Monad.ST (runST,ST)
import Data.Array.ST (newArray,readArray,MArray,STUArray)
import Data.Bits (shiftL,shiftR,testBit)
import Data.Char (chr)
import Data.Maybe (fromMaybe)
import Data.Word (Word32,Word64)
import qualified Data.Map as Map
#if __GLASGOW_HASKELL__ >= 704
import Data.Array.Unsafe (castSTUArray)
#else
import Data.Array.ST (castSTUArray)
#endif
-- | Decode a binary-operator opcode from a record field.
--
-- The field is read with 'numeric' and mapped to a builder that expects the
-- optional flags word, the typed left operand and the right operand.  The
-- first argument wraps arithmetic operators, the second wraps bitwise
-- operators.  Opcodes outside @0..12@ fail the match.  Whenever the left
-- operand has a floating-point primitive type, the floating-point variant of
-- the operator is chosen and the flags word is ignored.
binopGeneric :: forall a.
                (ArithOp -> Typed PValue -> PValue -> a)
             -> (BitOp   -> Typed PValue -> PValue -> a)
             -> Match Field (Maybe Int -> Typed PValue -> PValue -> a)
binopGeneric aop bop = fromCode <=< numeric
  where

  -- True when the operand's type is a floating-point primitive type.
  isFloatOperand operand =
    case typedType operand of
      PrimType (FloatType _) -> True
      _                      -> False

  -- Operators without flags: the builder depends on the operand type only.
  constant mkInt mkFloat = return $ \ _flags lhs rhs ->
    if isFloatOperand lhs
       then mkFloat lhs rhs
       else mkInt lhs rhs

  -- The wrap flags live in bit 0 (nuw) and bit 1 (nsw) of the flags word.
  hasNuw w = testBit w 0
  hasNsw w = testBit w 1

  -- Operators taking two wrap flags; a missing flags word means both are
  -- False.
  wrapFlags wrap intOp floatOp = return $ \ mbFlags lhs rhs ->
    let apply op = wrap op lhs rhs
    in if isFloatOperand lhs
          then apply floatOp
          else maybe (apply (intOp False False))
                     (\ w -> apply (intOp (hasNuw w) (hasNsw w)))
                     mbFlags

  -- The exact flag lives in bit 0 of the flags word.
  isExact w = testBit w 0

  -- Operators taking a single exact flag; a missing flags word means False.
  exactFlag wrap intOp floatOp = return $ \ mbFlags lhs rhs ->
    let apply op = wrap op lhs rhs
    in if isFloatOperand lhs
          then apply floatOp
          else maybe (apply (intOp False))
                     (\ w -> apply (intOp (isExact w)))
                     mbFlags

  -- Opcode table.  The bitwise operators have no floating-point variant, so
  -- their float slot is an error that only fires if it is ever demanded.
  fromCode :: Match Int (Maybe Int -> Typed PValue -> PValue -> a)
  fromCode 0  = wrapFlags aop Add  FAdd
  fromCode 1  = wrapFlags aop Sub  FSub
  fromCode 2  = wrapFlags aop Mul  FMul
  fromCode 3  = exactFlag aop UDiv FDiv
  fromCode 4  = exactFlag aop SDiv FDiv
  fromCode 5  = constant (aop URem) (aop FRem)
  fromCode 6  = constant (aop SRem) (aop FRem)
  fromCode 7  = wrapFlags bop Shl  (error "invalid shl on floating point")
  fromCode 8  = exactFlag bop Lshr (error "invalid lshr on floating point")
  fromCode 9  = exactFlag bop Ashr (error "invalid ashr on floating point")
  fromCode 10 = constant (bop And) (error "invalid and on floating point")
  fromCode 11 = constant (bop Or)  (error "invalid or on floating point")
  fromCode 12 = constant (bop Xor) (error "invalid xor on floating point")
  fromCode _  = mzero
-- | Binary-operator decoder that produces instructions: arithmetic opcodes
-- are wrapped with 'Arith' and bitwise opcodes with 'Bit'.
binop :: Match Field (Maybe Int -> Typed PValue -> PValue -> PInstr)
binop opcodeField = binopGeneric Arith Bit opcodeField
-- | Binary-operator decoder that produces constant expressions: arithmetic
-- opcodes become 'ConstArith' and bitwise opcodes become 'ConstBit', each
-- wrapped in 'ValConstExpr'.
binopCE :: Match Field (Maybe Int -> Typed PValue -> PValue -> PValue)
binopCE = binopGeneric arithExpr bitExpr
  where
  arithExpr op lhs rhs = ValConstExpr $ ConstArith op lhs rhs
  bitExpr   op lhs rhs = ValConstExpr $ ConstBit   op lhs rhs
-- | Decode a floating-point comparison predicate from a record field.
--
-- The field is read with 'numeric'; codes @0..15@ select the predicate at
-- that position in the table below, and any other code fails the match.
fcmpOp :: Match Field FCmpOp
fcmpOp = fromCode <=< numeric
  where
  fromCode :: Match Int FCmpOp
  fromCode code = lookup code (zip [0 ..] predicates)

  -- Predicates listed in code order, starting at code 0.
  predicates :: [FCmpOp]
  predicates =
    [ Ffalse, Foeq, Fogt, Foge
    , Folt,   Fole, Fone, Ford
    , Funo,   Fueq, Fugt, Fuge
    , Fult,   Fule, Fune, Ftrue
    ]
-- | Decode an integer comparison predicate from a record field.
--
-- The field is read with 'numeric'; codes @32..41@ select the predicate at
-- that position in the table below, and any other code fails the match.
icmpOp :: Match Field ICmpOp
icmpOp = fromCode <=< numeric
  where
  fromCode :: Match Int ICmpOp
  fromCode code = lookup code (zip [32 ..] predicates)

  -- Predicates listed in code order, starting at code 32.
  predicates :: [ICmpOp]
  predicates =
    [ Ieq,  Ine
    , Iugt, Iuge, Iult, Iule
    , Isgt, Isge, Islt, Isle
    ]
-- | Decode a cast opcode from a record field and hand the resulting
-- 'ConvOp' to the supplied continuation.
--
-- The field is read with 'numeric'; codes @0..11@ select the conversion at
-- that position in the table below, and any other code fails the match.  The
-- continuation may itself fail by returning 'Nothing'.
castOpGeneric :: forall c. (ConvOp -> Maybe c) -> Match Field c
castOpGeneric op = fromCode <=< numeric
  where
  fromCode :: Match Int c
  fromCode code = op =<< lookup code (zip [0 ..] conversions)

  -- Conversions listed in code order, starting at code 0.
  conversions :: [ConvOp]
  conversions =
    [ Trunc,  ZExt,   SExt
    , FpToUi, FpToSi, UiToFp, SiToFp
    , FpTrunc, FpExt
    , PtrToInt, IntToPtr, BitCast
    ]
-- | Cast-opcode decoder that produces instructions by wrapping the decoded
-- conversion in 'Conv'.
castOp :: Match Field (Typed PValue -> Type -> PInstr)
castOp = castOpGeneric (\ conv -> Just (Conv conv))
-- | Cast-opcode decoder that produces constant expressions: the decoded
-- conversion becomes a 'ConstConv' wrapped in 'ValConstExpr'.
castOpCE :: Match Field (Typed PValue -> Type -> PValue)
castOpCE = castOpGeneric (Just . constConv)
  where
  constConv conv operand ty = ValConstExpr (ConstConv conv operand ty)
-- | A map from 'Int' keys to typed values.
-- NOTE(review): presumably keyed by constant index; the alias is not used
-- in the code visible here -- confirm against its users.
type ConstantTable = Map.Map Int (Typed Value)
-- | Match an entry accepted by 'fromEntry' whose record then passes
-- @hasRecordCode 12@ (the code handled as @CST_CODE_CE_GEP@ in
-- 'parseConstantEntry').
cstGep :: Match Entry Record
cstGep entry = fromEntry entry >>= hasRecordCode 12
-- | Match an entry accepted by 'fromEntry' whose record then passes
-- @hasRecordCode 20@ (the code handled as @CST_CODE_CE_INBOUNDS_GEP@ in
-- 'parseConstantEntry').
cstInboundsGep :: Match Entry Record
cstInboundsGep entry = fromEntry entry >>= hasRecordCode 20
-- | Build the action that yields the current constant type for a type id.
-- This delegates directly to @getType'@.
setCurType :: Int -> Parse Type
setCurType typeId = getType' typeId
-- | Parse all entries of a constants block.
--
-- The entries are folded with 'parseConstantEntry' inside 'fixValueTable_',
-- which supplies the value table handed to every entry and receives the
-- accumulated constants.  Until a record installs a type, asking for the
-- current type fails.
parseConstantsBlock :: [Entry] -> Parse ()
parseConstantsBlock es = fixValueTable_ collect
  where

  -- Initial "current type" action: no type has been set yet.
  noType = fail "no current type id set"

  -- Fold over the entries and keep only the accumulated constants.
  collect table = do
    (_, consts) <- foldM (parseConstantEntry table) (noType, []) es
    return consts
-- | Parse one entry of a constants block, threading the parser state.
--
-- The state is a pair of an action yielding the current constant type (as
-- installed by the most recent @CST_CODE_SETTYPE@ record) and the constants
-- parsed so far, most recent first.  Records are dispatched on their record
-- code, abbreviation definitions leave the state untouched, and any other
-- entry is a parse failure.  The 'ValueTable' argument is used to build
-- forward references to other values.
parseConstantEntry :: ValueTable -> (Parse Type,[Typed PValue]) -> Entry
                   -> Parse (Parse Type, [Typed PValue])

parseConstantEntry t (getTy,cs) (fromEntry -> Just r) =
  label "CONSTANTS_BLOCK" $ case recordCode r of

    -- field 0: type id that becomes the current type
    1 -> label "CST_CODE_SETTYPE" $ do
      let field = parseField r
      i <- field 0 numeric
      return (setCurType i, cs)

    2 -> label "CST_CODE_NULL" $ do
      ty  <- getTy
      val <- resolveNull ty
      return (getTy, Typed ty val:cs)

    3 -> label "CST_CODE_UNDEF" $ do
      ty <- getTy
      return (getTy, Typed ty ValUndef:cs)

    -- field 0: sign-encoded integer value
    4 -> label "CST_CODE_INTEGER" $ do
      let field = parseField r
      ty <- getTy
      n  <- field 0 signedWord64
      -- NOTE(review): the 'ValBool' branch only fires for a zero-width
      -- integer type; confirm whether @Integer 1@ was intended.
      let val = fromMaybe (ValInteger (toInteger n)) $ do
                  Integer 0 <- elimPrimType ty
                  return (ValBool (n /= 0))
      return (getTy, Typed ty val:cs)

    -- all fields: limbs of the integer, see 'parseWideInteger'
    5 -> label "CST_CODE_WIDE_INTEGER" $ do
      ty <- getTy
      n  <- parseWideInteger r
      return (getTy, Typed ty (ValInteger n):cs)

    -- field 0: bit pattern of the floating-point value
    6 -> label "CST_CODE_FLOAT" $ do
      let field = parseField r
      ty <- getTy
      ft <- (elimFloatType =<< elimPrimType ty)
              `mplus` fail "expecting a float type"
      let build k = do
            w <- field 0 numeric
            return (getTy, (Typed ty $! k w):cs)
      -- every float type other than 'Float' is decoded as a double
      case ft of
        Float -> build (ValFloat  . castFloat)
        _     -> build (ValDouble . castDouble)

    -- fields: value ids of the elements, either as one array field or as
    -- the individual record fields
    7 -> label "CST_CODE_AGGREGATE" $ do
      ty    <- getTy
      elems <- parseField r 0 (fieldArray numeric)
                 `mplus` parseFields r 0 numeric
      cxt   <- getContext
      let vals = [ forwardRef cxt ix t | ix <- elems ]
      case ty of

        Struct _fs ->
          return (getTy, Typed ty (ValStruct vals):cs)

        PackedStruct _fs ->
          return (getTy, Typed ty (ValPackedStruct vals):cs)

        Array _n fty ->
          return (getTy, Typed ty (ValArray fty (map typedValue vals)):cs)

        Vector _n ety ->
          return (getTy, Typed ty (ValVector ety (map typedValue vals)):cs)

        -- any other type yields an undefined value of that type
        _ -> return (getTy, Typed ty ValUndef:cs)

    8 -> label "CST_CODE_STRING" $ do
      let field = parseField r
      ty     <- getTy
      values <- field 0 string
      return (getTy, Typed ty (ValString values):cs)

    -- like CST_CODE_STRING, but a terminating NUL character is appended
    9 -> label "CST_CODE_CSTRING" $ do
      ty     <- getTy
      values <- parseField r 0 cstring
                  `mplus` parseFields r 0 (fieldChar6 ||| char)
      return (getTy, Typed ty (ValString (values ++ [chr 0])):cs)

    -- field 0: opcode, field 1: lhs value id, field 2: rhs value id,
    -- optional field 3: flags word
    10 -> label "CST_CODE_CE_BINOP" $ do
      let field = parseField r
      ty      <- getTy
      mkInstr <- field 0 binopCE
      lopval  <- field 1 numeric
      ropval  <- field 2 numeric
      cxt     <- getContext
      let lv = forwardRef cxt lopval t
          rv = forwardRef cxt ropval t
      let mbWord = numeric =<< fieldAt 3 r
      return (getTy, Typed ty (mkInstr mbWord lv (typedValue rv)) : cs)

    -- field 0: cast opcode, field 2: operand value id (field 1 is not read)
    11 -> label "CST_CODE_CE_CAST" $ do
      let field = parseField r
      ty    <- getTy
      cast' <- field 0 castOpCE
      opval <- field 2 numeric
      cxt   <- getContext
      return (getTy,Typed ty (cast' (forwardRef cxt opval t) ty):cs)

    12 -> label "CST_CODE_CE_GEP" $ do
      ty <- getTy
      v  <- parseCeGep False t r
      return (getTy,Typed ty v:cs)

    -- fields 0..2: value ids of the three select operands
    13 -> label "CST_CODE_CE_SELECT" $ do
      let field = parseField r
      ty  <- getTy
      ix1 <- field 0 numeric
      ix2 <- field 1 numeric
      ix3 <- field 2 numeric
      cxt <- getContext
      let ref ix = forwardRef cxt ix t
          ce     = ConstSelect (ref ix1) (ref ix2) (ref ix3)
      return (getTy, Typed ty (ValConstExpr ce):cs)

    14 -> label "CST_CODE_CE_EXTRACTELT" notImplemented

    15 -> label "CST_CODE_CE_INSERTELT" notImplemented

    16 -> label "CST_CODE_CE_SHUFFLEVEC" notImplemented

    -- field 0: operand type id, fields 1 and 2: operand value ids,
    -- field 3: predicate
    17 -> label "CST_CODE_CE_CMP" $ do
      let field = parseField r
      opty <- getType =<< field 0 numeric
      op0  <- getConstantFwdRef t opty =<< field 1 numeric
      op1  <- getConstantFwdRef t opty =<< field 2 numeric

      -- floating-point operands (scalar or vector) use the fcmp predicates
      let isFloat = isPrimTypeOf isFloatingPoint
      cst <- if isFloat opty || isVectorOf isFloat opty
                then do op <- field 3 fcmpOp
                        return (ConstFCmp op op0 op1)
                else do op <- field 3 icmpOp
                        return (ConstICmp op op0 op1)
      -- NOTE(review): the result is always typed as a one-bit integer, even
      -- for vector operands; confirm that this is intended.
      return (getTy, Typed (PrimType (Integer 1)) (ValConstExpr cst):cs)

    -- field 0: flags, field 1: asm string length, then the asm string,
    -- then the constraint string length and the constraint string
    18 -> label "CST_CODE_INLINEASM_OLD" $ do
      let field = parseField r
      ty    <- getTy
      flags <- field 0 numeric
      let sideEffect = testBit (flags :: Int) 0
          alignStack = (flags `shiftR` 1) == 1
      alen <- field 1 numeric
      asm  <- parseSlice r 2 alen char
      clen <- field (2+alen) numeric
      cst  <- parseSlice r (3+alen) clen char
      return (getTy, Typed ty (ValAsm sideEffect alignStack asm cst):cs)

    19 -> label "CST_CODE_CE_SHUFFLEVEC_EX" notImplemented

    20 -> label "CST_CODE_CE_INBOUNDS_GEP" $ do
      ty <- getTy
      v  <- parseCeGep True t r
      return (getTy,Typed ty v:cs)

    -- field 1: value id of the function, field 2: block id
    -- (field 0 is not read)
    21 -> label "CST_CODE_BLOCKADDRESS" $ do
      let field = parseField r
      ty  <- getTy
      val <- getValue ty =<< field 1 numeric
      bid <- field 2 numeric
      sym <- elimValSymbol (typedValue val)
               `mplus` fail "invalid function symbol in BLOCKADDRESS record"
      let ce = ConstBlockAddr sym bid
      return (getTy, Typed ty (ValConstExpr ce):cs)

    -- all fields: the elements of an array or vector of primitive values
    22 -> label "CST_CODE_DATA" $ do
      ty     <- getTy
      elemTy <- (elimPrimType =<< elimSequentialType ty)
                  `mplus` fail "invalid container type for CST_CODE_DATA"
      let build mk = do
            ns <- parseFields r 0 numeric
            let elems = map mk ns
                val | isArray ty = ValArray  (PrimType elemTy) elems
                    | otherwise  = ValVector (PrimType elemTy) elems
            return (getTy, Typed ty val : cs)
      case elemTy of
        Integer 8        -> build ValInteger
        Integer 16       -> build ValInteger
        Integer 32       -> build ValInteger
        Integer 64       -> build ValInteger
        -- Floating-point elements are read as raw bit patterns, so they are
        -- reinterpreted the same way CST_CODE_FLOAT does above; converting
        -- the bit pattern numerically would produce the wrong value.
        FloatType Float  -> build (ValFloat  . castFloat)
        FloatType Double -> build (ValDouble . castDouble)
        _                -> fail "unknown element type in CE_DATA"

    -- field 0: flag mask, field 1: asm string length, then the asm string,
    -- then the constraint string length and the constraint string
    23 -> label "CST_CODE_INLINEASM" $ do
      let field = parseField r
      mask <- field 0 numeric
      let test           = testBit (mask :: Word32)
          hasSideEffects = test 0
          isAlignStack   = test 1
          _asmDialect    = mask `shiftR` 2

      -- reject records too short to hold the announced strings
      let len = length (recordFields r)
      asmStrSize <- field 1 numeric
      when (2 + asmStrSize >= len)
           (fail "Invalid record")

      constStrSize <- field (2 + asmStrSize) numeric
      when (3 + asmStrSize + constStrSize > len)
           (fail "Invalid record")

      asmStr   <- parseSlice r 2 asmStrSize char
      constStr <- parseSlice r (3 + asmStrSize) constStrSize char

      ty <- getTy
      let val = ValAsm hasSideEffects isAlignStack asmStr constStr
      return (getTy, Typed ty val : cs)

    code -> fail ("unknown constant record code: " ++ show code)

-- abbreviation definitions carry no constants
parseConstantEntry _ st (abbrevDef -> Just _) =
  return st

parseConstantEntry _ _ e =
  fail ("constant block: unexpected: " ++ show e)
-- | Parse the fields of a constant GEP record into a 'ConstGEP' expression.
--
-- When the record has an odd number of fields, field 0 is read as the type
-- id of an explicit pointee type and the operands start at field 1;
-- otherwise there is no pointee type and the operands start at field 0.
-- Operands are pairs of a type id and a value id.  The first pair must
-- parse; collection then stops at the first later pair that fails.  The
-- 'Bool' is passed through as the in-bounds flag.
parseCeGep :: Bool -> ValueTable -> Record -> Parse PValue
parseCeGep isInbounds t r = do
  mPointeeType <- pointeeType
  args         <- operandsFrom firstOperand
  return $! ValConstExpr (ConstGEP isInbounds mPointeeType args)
  where

  -- An odd field count signals a leading explicit pointee type.
  hasPointeeType = odd (length (recordFields r))

  firstOperand :: Int
  firstOperand
    | hasPointeeType = 1
    | otherwise      = 0

  pointeeType
    | hasPointeeType = Just <$> (getType =<< parseField r 0 numeric)
    | otherwise      = pure Nothing

  -- Read (type id, value id) pairs starting at the given field index.
  operandsFrom ix = do
    ty   <- getType =<< parseField r ix numeric
    elt  <- parseField r (ix + 1) numeric
    rest <- operandsFrom (ix + 2) `mplus` return []
    cxt  <- getContext
    return (Typed ty (typedValue (forwardRef cxt elt t)) : rest)
-- | Assemble an arbitrary-precision integer from all fields of a record.
--
-- Every field is read with 'signedWord64' as one 64-bit limb; the first
-- field is the least significant limb and each following field sits 64 bits
-- higher.
parseWideInteger :: Record -> Parse Integer
parseWideInteger r = foldr addLimb 0 <$> parseFields r 0 signedWord64
  where
  -- Shift the more significant part up by one limb and add this limb.
  addLimb limb higher = (higher `shiftL` 64) + toInteger limb
-- | Compute the null value of a type.
--
-- 'typeNull' either provides the null value directly ('HasNull') or names
-- another type ('ResolveNull'); in the latter case the name is turned into
-- a type id with 'getTypeId', the type is fetched with @getType'@ and the
-- resolution is repeated on it.
resolveNull :: Type -> Parse PValue
resolveNull ty =
  case typeNull ty of
    HasNull nv        -> return nv
    ResolveNull ident -> do
      tyId   <- getTypeId ident
      target <- getType' tyId
      resolveNull target
-- | Reinterpret a 32-bit word as a 'Float' by running 'cast' in 'ST'.
castFloat :: Word32 -> Float
castFloat bits = runST (cast bits)
-- | Reinterpret a 64-bit word as a 'Double' by running 'cast' in 'ST'.
castDouble :: Word64 -> Double
castDouble bits = runST (cast bits)
-- | Reinterpret a value as another unboxed type.
--
-- The value is stored in a one-element unboxed array, the array is
-- converted with 'castSTUArray', and element 0 is read back at the result
-- type.
cast :: (MArray (STUArray s) a (ST s), MArray (STUArray s) b (ST s))
     => a -> ST s b
cast x = newArray (0 :: Int, 0) x >>= castSTUArray >>= flip readArray 0