{-# LANGUAGE MultiParamTypeClasses #-}
module LLVM.Internal.OrcJIT where
import LLVM.Prelude
import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.Bits
import Data.ByteString (packCString, useAsCString)
import Data.IORef
import Foreign.C.String
import Foreign.Ptr
import LLVM.Internal.Coding
import LLVM.Internal.Target
import qualified LLVM.Internal.FFI.DataLayout as FFI
import qualified LLVM.Internal.FFI.LLVMCTypes as FFI
import qualified LLVM.Internal.FFI.OrcJIT as FFI
import qualified LLVM.Internal.FFI.Target as FFI
-- | A symbol name held as raw bytes; converted to and from 'CString' by its
-- 'EncodeM' / 'DecodeM' instances.
newtype MangledSymbol
  = MangledSymbol ByteString
  deriving (Eq, Ord, Show)
-- | Marshal a 'MangledSymbol' to a NUL-terminated C string (via
-- 'useAsCString'). The string is only valid inside the remainder of the
-- enclosing 'AnyContT' computation.
instance EncodeM (AnyContT IO) MangledSymbol CString where
  encodeM (MangledSymbol bytes) = anyContToM (useAsCString bytes)
-- | Copy a NUL-terminated C string into a fresh 'MangledSymbol'
-- (via 'packCString').
instance MonadIO m => DecodeM m MangledSymbol CString where
  decodeM = liftIO . fmap MangledSymbol . packCString
-- | Haskell handle wrapping a pointer to a native execution session.
newtype ExecutionSession
  = ExecutionSession (Ptr FFI.ExecutionSession)
-- | Boolean view of a symbol's flag bits. Each field corresponds to one bit
-- of 'FFI.JITSymbolFlags' (see the 'EncodeM' / 'DecodeM' instances).
data JITSymbolFlags = JITSymbolFlags
  { jitSymbolWeak :: !Bool
    -- ^ Mirrors the 'FFI.jitSymbolFlagsWeak' bit.
  , jitSymbolCommon :: !Bool
    -- ^ Mirrors the 'FFI.jitSymbolFlagsCommon' bit.
  , jitSymbolAbsolute :: !Bool
    -- ^ Mirrors the 'FFI.jitSymbolFlagsAbsolute' bit.
  , jitSymbolExported :: !Bool
    -- ^ Mirrors the 'FFI.jitSymbolFlagsExported' bit.
  } deriving (Show, Eq, Ord)
-- | Flags with every bit cleared.
defaultJITSymbolFlags :: JITSymbolFlags
defaultJITSymbolFlags =
  JITSymbolFlags
    { jitSymbolWeak = False
    , jitSymbolCommon = False
    , jitSymbolAbsolute = False
    , jitSymbolExported = False
    }
-- | A successfully resolved symbol.
data JITSymbol = JITSymbol
  { jitSymbolAddress :: !WordPtr
    -- ^ Address of the symbol.
  , jitSymbolFlags :: !JITSymbolFlags
    -- ^ Flags attached to the symbol.
  } deriving (Show, Eq, Ord)
-- | Error produced by a failed symbol lookup, carrying the error message.
--
-- A 'newtype' rather than a single-field 'data': same constructor and
-- instances, but no extra box around the message.
newtype JITSymbolError = JITSymbolError ShortByteString
  deriving (Show, Eq)
-- | A Haskell symbol-resolution callback: given a mangled name, return the
-- resolved symbol or an error.
newtype SymbolResolver
  = SymbolResolver (MangledSymbol -> IO (Either JITSymbolError JITSymbol))
-- | Run an action with a native symbol resolver backed by a Haskell callback.
--
-- The callback is wrapped into a C function pointer by the 'EncodeM'
-- instance for resolver callbacks (the same conversion used when encoding a
-- 'SymbolResolver'). A lambda resolver is then created in the given execution
-- session. Both are released when the action finishes, in reverse order of
-- creation, even if the action throws.
withSymbolResolver :: ExecutionSession -> SymbolResolver -> (Ptr FFI.SymbolResolver -> IO a) -> IO a
withSymbolResolver (ExecutionSession es) (SymbolResolver resolverFn) f =
  bracket (encodeM resolverFn) freeHaskellFunPtr $ \resolverPtr ->
    bracket (FFI.createLambdaResolver es resolverPtr) FFI.disposeSymbolResolver f
-- | Pack the boolean flag fields into the raw FFI bitmask: each set field
-- contributes its bit, and an all-'False' record encodes to 0.
instance Monad m => EncodeM m JITSymbolFlags FFI.JITSymbolFlags where
  encodeM flags =
    pure $ foldr (.|.) 0 [flagBit | (isSet, flagBit) <- table, isSet flags]
    where
      table =
        [ (jitSymbolWeak, FFI.jitSymbolFlagsWeak)
        , (jitSymbolCommon, FFI.jitSymbolFlagsCommon)
        , (jitSymbolAbsolute, FFI.jitSymbolFlagsAbsolute)
        , (jitSymbolExported, FFI.jitSymbolFlagsExported)
        ]
-- | Unpack the raw FFI bitmask: a field is 'True' exactly when its bit is set.
-- Bits other than the four listed are ignored.
instance Monad m => DecodeM m JITSymbolFlags FFI.JITSymbolFlags where
  decodeM raw =
    pure $ JITSymbolFlags
      { jitSymbolWeak = has FFI.jitSymbolFlagsWeak
      , jitSymbolCommon = has FFI.jitSymbolFlagsCommon
      , jitSymbolAbsolute = has FFI.jitSymbolFlagsAbsolute
      , jitSymbolExported = has FFI.jitSymbolFlagsExported
      }
    where
      has flagBit = flagBit .&. raw /= 0
-- | Build an action that writes a lookup result into an FFI symbol record.
--
-- A 'Left' writes address 0 with only 'FFI.jitSymbolFlagsHasError' set; the
-- message inside the 'JITSymbolError' is not forwarded. A 'Right' writes the
-- symbol's address together with its encoded flags.
instance MonadIO m => EncodeM m (Either JITSymbolError JITSymbol) (Ptr FFI.JITSymbol -> IO ()) where
  encodeM (Left (JITSymbolError _)) =
    pure (\ptr -> FFI.setJITSymbol ptr (FFI.TargetAddress 0) FFI.jitSymbolFlagsHasError)
  encodeM (Right (JITSymbol addr flags)) =
    pure (\ptr -> FFI.setJITSymbol ptr (FFI.TargetAddress (fromIntegral addr)) =<< encodeM flags)
-- | Read a lookup result out of an FFI symbol record.
--
-- The lookup counts as failed when the address is 0 or the
-- 'FFI.jitSymbolFlagsHasError' bit is set. In that case the error text is
-- fetched with 'FFI.getErrorMsg'. Otherwise the raw flags are decoded and the
-- address is converted to a 'WordPtr'.
instance (MonadIO m, MonadAnyCont IO m) => DecodeM m (Either JITSymbolError JITSymbol) (Ptr FFI.JITSymbol) where
  decodeM jitSymbol = do
    -- Out-parameter buffer required by 'FFI.getAddress'. It has a distinct
    -- name so the decoded message below no longer shadows it.
    errMsgBuf <- alloca
    FFI.TargetAddress addr <- liftIO $ FFI.getAddress jitSymbol errMsgBuf
    rawFlags <- liftIO (FFI.getFlags jitSymbol)
    if addr == 0 || (rawFlags .&. FFI.jitSymbolFlagsHasError /= 0)
      then do
        errMsg <- decodeM =<< liftIO (FFI.getErrorMsg jitSymbol)
        pure (Left (JITSymbolError errMsg))
      else do
        flags <- decodeM rawFlags
        pure (Right (JITSymbol (fromIntegral addr) flags))
-- | Turn a 'SymbolResolver' into a constructor for a native lambda resolver.
--
-- The constructor wraps the callback as a function pointer and creates the
-- resolver in the given session. The matching release actions
-- ('freeHaskellFunPtr' and 'FFI.disposeSymbolResolver') are pushed onto the
-- supplied cleanup list rather than run here.
instance MonadIO m =>
  EncodeM m SymbolResolver (IORef [IO ()] -> Ptr FFI.ExecutionSession -> IO (Ptr FFI.SymbolResolver)) where
  encodeM (SymbolResolver callback) = pure build
    where
      build cleanupsRef session = do
        callbackPtr <- allocFunPtr cleanupsRef (encodeM callback)
        allocWithCleanup cleanupsRef
          (FFI.createLambdaResolver session callbackPtr)
          FFI.disposeSymbolResolver
-- | Wrap a Haskell resolver callback as a C-callable function pointer.
--
-- The generated trampoline works in three steps. It decodes the incoming C
-- string into a 'MangledSymbol' and runs the callback. It then writes the
-- result into the provided symbol record. The returned 'FunPtr' must
-- eventually be released with 'freeHaskellFunPtr'.
instance MonadIO m => EncodeM m (MangledSymbol -> IO (Either JITSymbolError JITSymbol)) (FunPtr FFI.SymbolResolverFn) where
  encodeM callback = liftIO (FFI.wrapSymbolResolverFn trampoline)
    where
      trampoline symbolPtr resultPtr = do
        symbol <- decodeM symbolPtr
        outcome <- callback symbol
        writeOutcome <- encodeM outcome
        writeOutcome resultPtr
-- | Acquire a resource and push its release action onto @cleanups@.
--
-- The whole step runs with asynchronous exceptions masked, like the acquire
-- phase of 'bracket'. Previously @alloc@ ran under @restore@, so an async
-- exception delivered after the resource had been created, but before its
-- release action was registered, leaked the resource. Interruptible
-- operations inside @alloc@ can still be interrupted, as with any 'mask_'.
--
-- Release actions are prepended, so running the list front-to-back releases
-- resources in reverse order of acquisition. The strict 'modifyIORef'' avoids
-- building a thunk inside the reference.
allocWithCleanup :: IORef [IO ()] -> IO a -> (a -> IO ()) -> IO a
allocWithCleanup cleanups alloc free = mask_ $ do
  a <- alloc
  modifyIORef' cleanups (free a :)
  pure a
-- | Create a 'FunPtr' and register 'freeHaskellFunPtr' for it on the
-- cleanup list (see 'allocWithCleanup').
allocFunPtr :: IORef [IO ()] -> IO (FunPtr a) -> IO (FunPtr a)
allocFunPtr cleanupsRef mkFunPtr =
  allocWithCleanup cleanupsRef mkFunPtr freeHaskellFunPtr
-- | Create a data layout for the given target machine and register its
-- disposal on the @cleanups@ list.
--
-- 'bracketOnError' runs its acquire action with async exceptions masked.
-- Because of this, once the layout exists its disposal action is always
-- pushed onto @cleanups@. If the remainder of the 'AnyContT' computation
-- throws, 'bracketOnError' additionally disposes the layout right away.
--
-- NOTE(review): on that error path the disposal action is still present in
-- @cleanups@. Presumably callers discard @cleanups@ when construction fails
-- rather than running it; otherwise the layout would be disposed twice.
-- Confirm against callers.
createRegisteredDataLayout :: (MonadAnyCont IO m) => TargetMachine -> IORef [IO ()] -> m (Ptr FFI.DataLayout)
createRegisteredDataLayout (TargetMachine tm) cleanups =
  let createDataLayout = do
        dl <- FFI.createTargetDataLayout tm
        -- Registered inside the (masked) acquire so the release cannot be lost.
        modifyIORef' cleanups (FFI.disposeDataLayout dl :)
        pure dl
  in anyContToM $ bracketOnError createDataLayout FFI.disposeDataLayout
-- | Create a new native execution session. It must be released with
-- 'disposeExecutionSession'; prefer 'withExecutionSession'.
createExecutionSession :: IO ExecutionSession
createExecutionSession = fmap ExecutionSession FFI.createExecutionSession
-- | Release a session obtained from 'createExecutionSession'.
disposeExecutionSession :: ExecutionSession -> IO ()
disposeExecutionSession session =
  case session of
    ExecutionSession sessionPtr -> FFI.disposeExecutionSession sessionPtr
-- | Run an action with a fresh execution session, disposing of it afterwards
-- even if the action throws.
withExecutionSession :: (ExecutionSession -> IO a) -> IO a
withExecutionSession action =
  bracket createExecutionSession disposeExecutionSession action
-- | Allocate a fresh module key in the session (via 'FFI.allocateVModule').
-- Pair it with 'releaseModuleKey', or use 'withModuleKey'.
allocateModuleKey :: ExecutionSession -> IO FFI.ModuleKey
allocateModuleKey (ExecutionSession sessionPtr) =
  FFI.allocateVModule sessionPtr
-- | Return a module key to the session (via 'FFI.releaseVModule').
releaseModuleKey :: ExecutionSession -> FFI.ModuleKey -> IO ()
releaseModuleKey (ExecutionSession sessionPtr) = FFI.releaseVModule sessionPtr
-- | Run an action with a freshly allocated module key, releasing the key
-- afterwards even if the action throws.
withModuleKey :: ExecutionSession -> (FFI.ModuleKey -> IO a) -> IO a
withModuleKey session action =
  bracket (allocateModuleKey session) (releaseModuleKey session) action