{-# language FlexibleInstances #-}
{-# language FlexibleContexts #-}
{-# language ViewPatterns #-}
{-# language TypeFamilies #-}
{-# language CPP #-}
module CodeGen.X86.FFI where

-------------------------------------------------------

import Control.Monad
import Control.Exception (evaluate)
import Control.DeepSeq
import Foreign
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import System.IO.Unsafe

import Control.DeeperSeq
import CodeGen.X86.Asm
import CodeGen.X86.CodeGen

-------------------------------------------------------

-- this should be queried from the OS dynamically...
#define PAGE_SIZE 4096

#if defined (mingw32_HOST_OS) || defined (mingw64_HOST_OS) 

import System.Win32.Types
import System.Win32.Mem
import Foreign.Marshal.Alloc

#endif

-- | Function types that can be invoked through a C function pointer.
--
-- The superclass constraints require that the type supports 'MapResult' and
-- that its final result ('Result') can be fully evaluated via 'NFData'.
class (MapResult a, NFData (Result a)) => Callable a where
    -- | Turn a function pointer into a Haskell function of the same type.
    dynCCall :: FunPtr a -> a

-- | Turn an action that produces a foreign pointer to machine code into a
-- Haskell function of type @a@.
--
-- The action is run via 'unsafePerformIO' when the pointer is first demanded.
-- The resulting pointer is cast to a 'FunPtr' and handed to 'dynCCall'.
-- The result of every call is post-processed by @f@ (through 'mapResult'),
-- which forces it to normal form and only then touches the 'ForeignPtr', so
-- the pointer is kept alive until the result has been fully evaluated.
{-# NOINLINE callForeignPtr #-}
callForeignPtr :: Callable a => IO (ForeignPtr a) -> a
callForeignPtr p_ = mapResult f (dynCCall $ cast $ unsafeForeignPtrToPtr p)
  where
    -- reinterpret the data pointer as a function pointer of the same type
    cast = castPtrToFunPtr :: Ptr a -> FunPtr a
    -- the foreign pointer itself, obtained by running the supplied action
    p = unsafePerformIO p_
    {-# NOINLINE f #-}
    -- fully evaluate the result first, then touch the pointer
    f x = unsafePerformIO $ evaluate (force x) <* touchForeignPtr p

-- | Haskell function types for which a C-callable function pointer can be
-- created.
class MapResult a => CallableHs a where
    -- | Create a function pointer for the given Haskell function.
    createHsPtr :: a -> IO (FunPtr a)

-- | Pure wrapper around 'createHsPtr': yields a function pointer for the given
-- Haskell function.
--
-- The pointer is created with 'unsafePerformIO', and nothing in this function
-- releases it again (see the TODO below for a 'ForeignPtr'-based variant).
hsPtr :: CallableHs a => a -> FunPtr a
hsPtr x = unsafePerformIO $ createHsPtr x
{- TODO
hsPtr' :: CallableHs a => a -> ForeignPtr a
hsPtr' x = unsafePerformIO $ do
    p <- createHsPtr x
    newForeignPtr (freeHaskellFunPtr p) (castFunPtrToPtr p)
-}

#if defined (mingw32_HOST_OS) || defined (mingw64_HOST_OS) 
-- note: GHC 64 bit also defines mingw32 ...

-- Aligned allocator from @malloc.h@; 'compile' passes the byte count first and
-- the alignment second.
foreign import ccall "static malloc.h  _aligned_malloc" c_aligned_malloc :: CSize -> CSize -> IO (Ptr a)
-- Matching deallocator, imported as a directly callable function.
foreign import ccall "static malloc.h  _aligned_free"   c_aligned_free   :: Ptr a -> IO ()
-- Address of the same deallocator, in the form 'newForeignPtr' takes as a finalizer.
foreign import ccall "static malloc.h &_aligned_free"   ptr_aligned_free :: FunPtr (Ptr a -> IO ())

#elif defined (linux_HOST_OS)

-- TODO: on Linux too, we should use posix_memalign instead of memalign.
-- 'compile' passes the alignment first and the size second.
foreign import ccall "static stdlib.h memalign"   memalign :: CSize -> CSize -> IO (Ptr a)
-- Address of @free@, used as the 'ForeignPtr' finalizer for the allocated buffer.
foreign import ccall "static stdlib.h &free"      stdfree  :: FunPtr (Ptr a -> IO ())
-- Page-protection call; 'compile' passes 0x7 (READ, WRITE, EXEC) as the flags.
foreign import ccall "static sys/mman.h mprotect" mprotect :: Ptr a -> CSize -> CInt -> IO CInt

#elif defined (darwin_HOST_OS) || defined (freebsd_HOST_OS) || defined (openbsd_HOST_OS) || defined (netbsd_HOST_OS) 

-- Raw binding; 'posixMemAlign' passes the out-pointer, then the alignment, then
-- the size, and treats a result of 0 as success.
foreign import ccall "static stdlib.h posix_memalign"   posix_memalign :: Ptr (Ptr a) -> CSize -> CSize -> IO CInt
-- Address of @free@, used as the 'ForeignPtr' finalizer for the allocated buffer.
foreign import ccall "static stdlib.h &free"            stdfree        :: FunPtr (Ptr a -> IO ())
-- Page-protection call; 'compile' passes 0x7 (READ, WRITE, EXEC) as the flags.
foreign import ccall "static sys/mman.h mprotect"       mprotect       :: Ptr a -> CSize -> CInt -> IO CInt

#endif

-------------------------------------------------------

#if defined (mingw32_HOST_OS) || defined (mingw64_HOST_OS) 

-- | Protection flag that 'compile' passes to 'virtualProtect'.
-- NOTE(review): presumably the Win32 @PAGE_EXECUTE_READWRITE@ constant, as the
-- name suggests — confirm against the Win32 headers.
flag_PAGE_EXECUTE_READWRITE :: Word32
flag_PAGE_EXECUTE_READWRITE = 0x40 

-- | Assemble the given 'Code' and expose it as a Haskell function (Windows).
--
-- The bytes produced by 'buildTheCode' are written into a buffer aligned to
-- @PAGE_SIZE@ whose protection has been set to 'flag_PAGE_EXECUTE_READWRITE'.
-- The buffer is released by @_aligned_free@ when the 'ForeignPtr' is finalized.
{-# NOINLINE compile #-}
compile :: Callable a => Code -> a
compile code = callForeignPtr $ do
    let (chunks, len) = buildTheCode code
    buf <- c_aligned_malloc (fromIntegral len) PAGE_SIZE
    _ <- virtualProtect (castPtr buf) (fromIntegral len) flag_PAGE_EXECUTE_READWRITE
    -- only the 'Right' entries carry an (offset, value) pair to store
    mapM_ (uncurry (pokeByteOff buf)) [entry | Right entry <- chunks]
    newForeignPtr ptr_aligned_free buf

#elif defined linux_HOST_OS

-- | Assemble the given 'Code' and expose it as a Haskell function (Linux).
--
-- The bytes produced by 'buildTheCode' are written into a buffer obtained from
-- @memalign@ with @PAGE_SIZE@ alignment, after @mprotect@ has been asked to
-- make it readable, writable and executable. The buffer is released by
-- @free@ when the 'ForeignPtr' is finalized.
{-# NOINLINE compile #-}
compile :: Callable a => Code -> a
compile x = callForeignPtr $ do
    -- the view pattern converts the 'Int' size to the 'CSize' the C calls expect
    let (bytes, fromIntegral -> size) = buildTheCode x
    arr <- memalign PAGE_SIZE size
    _ <- mprotect arr size 0x7 -- READ, WRITE, EXEC
    -- only the 'Right' entries carry an (offset, byte) pair to store
    forM_ [p | Right p <- bytes] $ uncurry $ pokeByteOff arr
    newForeignPtr stdfree arr

#elif defined (darwin_HOST_OS) || defined (freebsd_HOST_OS) || defined (openbsd_HOST_OS) || defined (netbsd_HOST_OS) 

-- | Allocate memory with @posix_memalign()@.
--
-- The requested alignment is raised to at least 8, and the requested size is
-- rounded /up/ to the next multiple of that alignment, so the returned buffer
-- always holds at least the requested number of bytes.
--
-- Calls 'error' if @posix_memalign@ returns a non-zero status.
posixMemAlign 
  :: CSize               -- ^ alignment
  -> CSize               -- ^ size
  -> IO (Ptr a)
posixMemAlign alignment size0 =
  alloca $ \pp -> do
    let a    = max alignment 8
        -- Round up to a multiple of @a@. The earlier @mod@ kept only the
        -- remainder, which under-allocated (e.g. 99 bytes for a 100-byte
        -- request at alignment 4096).
        size = ((size0 + a - 1) `div` a) * a
    -- pass the clamped alignment, not the raw argument
    res <- posix_memalign pp a size
    case res of
      0 -> peek pp
      _ -> error "posix_memalign failed"
      
-- | Assemble the given 'Code' and expose it as a Haskell function
-- (Darwin / BSD).
--
-- The bytes produced by 'buildTheCode' are written into a buffer obtained from
-- 'posixMemAlign' with @PAGE_SIZE@ alignment, after @mprotect@ has been asked
-- to make it readable, writable and executable. The buffer is released by
-- @free@ when the 'ForeignPtr' is finalized.
{-# NOINLINE compile #-}
compile :: Callable a => Code -> a
compile code = callForeignPtr $ do
    let (chunks, len) = buildTheCode code
        nbytes        = fromIntegral len :: CSize
    buf <- posixMemAlign PAGE_SIZE nbytes
    _ <- mprotect buf nbytes 0x7 -- READ, WRITE, EXEC
    -- only the 'Right' entries carry an (offset, value) pair to store
    mapM_ (uncurry (pokeByteOff buf)) [entry | Right entry <- chunks]
    newForeignPtr stdfree buf

#endif