{-# LANGUAGE MagicHash, UnboxedTuples #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, FunctionalDependencies #-}
{-# LANGUAGE DataKinds, PolyKinds, TypeFamilies #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}

#include "MachDeps.h"

module Language.Asm.Inline
( defineAsmFun
, defineAsmFunM
, Unit(..)
) where

import qualified Data.ByteString as BS
import Control.Monad
import Control.Monad.Primitive
import Data.Generics.Uniplate.Data
import Data.List
import Foreign.Ptr
import GHC.Int
import GHC.Prim
import GHC.Ptr
import GHC.Types hiding (Type)
import GHC.Word
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import System.IO.Unsafe

import Language.Asm.Inline.AsmCode
import Language.Asm.Inline.Util

-- | Values that can be handed to, or received from, inline assembly.
--
-- @boxed@ is the ordinary Haskell type callers see; @prim@ is the
-- (possibly unlifted) primitive type used by the raw @foreign import prim@.
-- The functional dependency makes the primitive type determined by the
-- boxed one.
class AsmArg boxed (prim :: TYPE r) | boxed -> prim where
  -- | Strip the box, exposing the primitive payload.
  unbox :: boxed -> prim
  -- | Wrap a primitive payload back into its boxed form.
  rebox :: prim -> boxed

-- | Placeholder for assembly arguments/results that carry no information.
data Unit = Unit

-- Primitive payload types wrapped by the sized integer constructors
-- (I8#, W16#, I64#, ...). GHC 9.2 gave the 8/16/32-bit types their own
-- sized primitives, and GHC 9.4 did the same for the 64-bit ones. Older
-- compilers wrap Int#/Word# instead (for the 64-bit types: on 64-bit targets).
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
type Int8Rep# = Int8#
type Int16Rep# = Int16#
type Int32Rep# = Int32#
type Word8Rep# = Word8#
type Word16Rep# = Word16#
type Word32Rep# = Word32#
#else
type Int8Rep# = Int#
type Int16Rep# = Int#
type Int32Rep# = Int#
type Word8Rep# = Word#
type Word16Rep# = Word#
type Word32Rep# = Word#
#endif

-- The 64-bit aliases are defined only here, so that each is declared exactly
-- once whatever the GHC version.
#if MIN_VERSION_GLASGOW_HASKELL(9,4,0,0)
type Int64Rep# = Int64#
type Word64Rep# = Word64#
#else
type Int64Rep# = Int#
type Word64Rep# = Word#
#endif

-- | 'Unit' crosses the boundary as a dummy @0#@; whatever primitive value
-- comes back in its place is discarded.
instance AsmArg Unit Int# where
  unbox = \_ -> 0#
  rebox = \_ -> Unit

-- | 'Int' is passed as the 'Int#' inside its 'I#' constructor.
instance AsmArg Int Int# where
  unbox boxed = case boxed of I# raw -> raw
  rebox raw = I# raw

-- | 'Int8' is passed as the payload of its 'I8#' constructor.
instance AsmArg Int8 Int8Rep# where
  unbox boxed = case boxed of I8# raw -> raw
  rebox raw = I8# raw

-- | 'Int16' is passed as the payload of its 'I16#' constructor.
instance AsmArg Int16 Int16Rep# where
  unbox boxed = case boxed of I16# raw -> raw
  rebox raw = I16# raw

-- | 'Int32' is passed as the payload of its 'I32#' constructor.
instance AsmArg Int32 Int32Rep# where
  unbox boxed = case boxed of I32# raw -> raw
  rebox raw = I32# raw

-- | 'Int64' is passed as the payload of its 'I64#' constructor; on 32-bit
-- targets that payload is always 'Int64#'.
#if WORD_SIZE_IN_BITS > 32
instance AsmArg Int64 Int64Rep# where
#else
instance AsmArg Int64 Int64# where
#endif
  unbox boxed = case boxed of I64# raw -> raw
  rebox raw = I64# raw

-- | 'Word' is passed as the 'Word#' inside its 'W#' constructor.
instance AsmArg Word Word# where
  unbox boxed = case boxed of W# raw -> raw
  rebox raw = W# raw

-- | 'Word8' is passed as the payload of its 'W8#' constructor.
instance AsmArg Word8 Word8Rep# where
  unbox boxed = case boxed of W8# raw -> raw
  rebox raw = W8# raw

-- | 'Word16' is passed as the payload of its 'W16#' constructor.
instance AsmArg Word16 Word16Rep# where
  unbox boxed = case boxed of W16# raw -> raw
  rebox raw = W16# raw

-- | 'Word32' is passed as the payload of its 'W32#' constructor.
instance AsmArg Word32 Word32Rep# where
  unbox boxed = case boxed of W32# raw -> raw
  rebox raw = W32# raw

-- | 'Word64' is passed as the payload of its 'W64#' constructor; on 32-bit
-- targets that payload is always 'Word64#'.
#if WORD_SIZE_IN_BITS > 32
instance AsmArg Word64 Word64Rep# where
#else
instance AsmArg Word64 Word64# where
#endif
  unbox boxed = case boxed of W64# raw -> raw
  rebox raw = W64# raw

-- | 'Double' is passed as the 'Double#' inside its 'D#' constructor.
instance AsmArg Double Double# where
  unbox boxed = case boxed of D# raw -> raw
  rebox raw = D# raw

-- | 'Float' is passed as the 'Float#' inside its 'F#' constructor.
instance AsmArg Float Float# where
  unbox boxed = case boxed of F# raw -> raw
  rebox raw = F# raw

-- | A 'Ptr' is passed as the raw 'Addr#' it wraps.
instance AsmArg (Ptr a) Addr# where
  unbox ptr = case ptr of Ptr addr -> addr
  rebox addr = Ptr addr

-- | @replace what with str@ substitutes @with@ for every non-overlapping
-- occurrence of @what@ in @str@, scanning left to right.
--
-- An empty @what@ leaves @str@ unchanged: an empty needle matches at every
-- position without consuming input, so scanning would never make progress.
replace :: String -> String -> String -> String
replace what with
  | null what = id
  | otherwise = go
  where
    -- Computed once rather than on every match.
    whatLen = length what

    go [] = []
    go str@(s:ss)
      | what `isPrefixOf` str = with <> go (drop whatLen str)
      | otherwise = s : go ss

-- | How a generated wrapper exposes its assembly routine.
data FunKind
  = Pure     -- ^ an ordinary function of its arguments
  | Monadic  -- ^ an action in any 'PrimMonad', threading the state token

-- | Shared implementation behind 'defineAsmFun' and 'defineAsmFunM'.
--
-- Registers the assembly as a foreign source under the global symbol
-- @<name>_unlifted@. Every @RET_HASK@ placeholder in the code is replaced by
-- the return jump, and one more return jump is appended after the code.
-- Returns, in order:
--
--   * a @foreign import prim@ of that symbol at the unlifted type,
--   * the type signature of the user-facing wrapper @name@,
--   * the wrapper definition, and
--   * an @INLINE@ pragma for the wrapper.
defineAsmFunImpl :: AsmCode tyAnn code => FunKind -> String -> tyAnn -> code -> Q [Dec]
defineAsmFunImpl kind name tyAnn asmCode = do
  let symbol       = name <> "_unlifted"
      -- Jumps to the address stored at (%rbp); %rbp is presumably GHC's
      -- x86-64 Haskell stack pointer, so this returns to the continuation.
      returnInsn   = "jmp *(%rbp)"
      wrapperName  = mkName name
      importedName = mkName symbol
      asmBody      = replace "RET_HASK" returnInsn (codeToString tyAnn asmCode)
  addForeignSource LangAsm $ unlines [".global " <> symbol, symbol <> ":", asmBody, returnInsn]
  funTy <- toTypeQ tyAnn
  (importedTy, sigTy) <- signatureTypes funTy
  wrapperFunD <- mkFunD kind name importedName funTy
  pure
    [ ForeignD (ImportF Prim Safe symbol importedName (unliftType importedTy))
    , SigD wrapperName sigTy
    , wrapperFunD
    , PragmaD (InlineP wrapperName Inline FunLike AllPhases)
    ]
  where
    -- Types of the raw import and of the wrapper, in that order.
    signatureTypes ty = case kind of
      Pure    -> pure (ty, ty)
      Monadic -> (,) <$> stateifyUnlifted ty <*> stateifyLifted ty

-- | Define a pure Haskell function @name@ implemented by the assembly in
-- @code@, with the Haskell type described by @tyAnn@.
defineAsmFun :: AsmCode tyAnn code => String -> tyAnn -> code -> Q [Dec]
defineAsmFun name tyAnn code = defineAsmFunImpl Pure name tyAnn code

-- | Like 'defineAsmFun', but the generated function returns its result as an
-- action in any 'PrimMonad' instead of as a pure value.
defineAsmFunM :: AsmCode tyAnn code => String -> tyAnn -> code -> Q [Dec]
defineAsmFunM name tyAnn code = defineAsmFunImpl Monadic name tyAnn code

-- | Build an unkinded type-variable binder for a 'ForallT'.
--
-- template-haskell 2.17 (GHC 9.0) parameterised 'TyVarBndr' over a flag, so
-- from that version on the binder must carry an explicit 'Specificity'.
#if MIN_VERSION_template_haskell(2, 17, 0)
mkPlainTV :: Name -> TyVarBndr Specificity
mkPlainTV n = PlainTV n SpecifiedSpec
#else
mkPlainTV :: Name -> TyVarBndr
mkPlainTV = PlainTV
#endif

-- | Converts the wrapped function type to live in a 'PrimMonad':
-- given @Ty1 -> Ty2 -> Ret@ it produces
-- @forall m. PrimMonad m => Ty1 -> Ty2 -> m Ret@.
stateifyLifted :: Type -> Q Type
stateifyLifted ty = do
  m <- newName "m"
  let monadTy = VarT m
      (argTys, resTy) = splitArrows ty
  monadicRes <- [t| $(pure monadTy) $(pure resTy) |]
  pure $ ForallT [mkPlainTV m] [AppT (ConT ''PrimMonad) monadTy] (foldr arrow monadicRes argTys)
  where
    -- Rebuild a single argument arrow @argTy -> rest@.
    arrow argTy rest = AppT (AppT ArrowT argTy) rest
    -- Peel the leading argument arrows off a type, returning the argument
    -- types (outermost first) and the final result type.
    splitArrows (AppT (AppT ArrowT argTy) rest) =
      let (moreArgs, res) = splitArrows rest in (argTy : moreArgs, res)
    splitArrows other = ([], other)

-- | Converts the unwrapped/unlifted function type to be a 'primitive' action:
-- given @Ty1# -> Ty2# -> Ret#@ it produces
-- @forall s. Ty1# -> Ty2# -> State# s -> (# State# s, Ret# #)@.
stateifyUnlifted :: Type -> Q Type
stateifyUnlifted ty = do
  s <- newName "s"
  let (argTys, resTy) = peelArrows ty
  stateRes <- [t| State# $(pure $ VarT s) -> (# State# $(pure $ VarT s), $(pure resTy) #) |]
  pure $ ForallT [mkPlainTV s] [] (foldr (\argTy acc -> AppT (AppT ArrowT argTy) acc) stateRes argTys)
  where
    -- Peel the leading argument arrows off a type, returning the argument
    -- types (outermost first) and the final result type.
    peelArrows (AppT (AppT ArrowT argTy) rest) =
      let (moreArgs, res) = peelArrows rest in (argTy : moreArgs, res)
    peelArrows other = ([], other)

-- | Generate the user-facing wrapper @funName@ around the raw foreign import
-- @importedName@, whose Haskell type is @funTy@.
--
-- Each argument is passed through 'unbox'. A 'BS.ByteString' argument is
-- passed as two values: the 'unbox'ed result of 'getBSAddr' and its
-- 'unbox'ed length. The result is passed through 'rebox', component-wise
-- when @funTy@ returns a tuple (the import returns an unboxed tuple then).
-- In 'Monadic' mode the call also receives a state token and the whole body
-- is wrapped in 'primitive'.
mkFunD :: FunKind -> String -> Name -> Type -> Q Dec
mkFunD kind funName importedName funTy = do
  token <- newName "token"
  argNames <- replicateM (countArgs funTy) (newName "arg")
  call <- foldM passArg (VarE importedName) (zip (map VarE argNames) (getArgs funTy))
  fullCall <- case kind of
    Pure    -> pure call
    Monadic -> [e| $(pure call) $(pure $ VarE token) |]
  body <- maybe (reboxSingle fullCall) (reboxTuple fullCall) (detectRetTuple funTy)
  wrapped <- case kind of
    Pure    -> pure body
    Monadic -> [e| primitive (\ $(pure $ VarP token) -> $(pure body)) |]
  pure $ FunD (mkName funName) [Clause (map VarP argNames) (NormalB wrapped) []]
  where
    -- Apply the call built so far to the unboxed form of one more argument.
    passArg acc (argName, argType)
      | argType == ConT ''BS.ByteString =
          [e| $(pure acc) (unbox $ getBSAddr $(pure argName)) (unbox $ BS.length $(pure argName)) |]
      | otherwise = [e| $(pure acc) (unbox $(pure argName)) |]

    -- Rebox a single (non-tuple) result.
    reboxSingle scrutinee = case kind of
      Pure    -> [e| rebox $(pure scrutinee) |]
      Monadic -> [e| case $(pure scrutinee) of { (# token', res #) -> (# token', rebox res #) } |]

    -- Rebox each component of an n-ary tuple result.
    reboxTuple scrutinee n = do
      retNames <- replicateM n (newName "ret")
      boxing <- forM retNames $ \retName -> Just <$> [e| rebox $(pure $ VarE retName) |]
      let retPat = UnboxedTupP (map VarP retNames)
          boxed  = TupE boxing
      case kind of
        Pure    -> [e| case $(pure scrutinee) of { $(pure retPat) -> $(pure boxed) } |]
        Monadic -> [e| case $(pure scrutinee) of { (# token', $(pure retPat) #) -> (# token', $(pure boxed) #) } |]

-- | Rewrite a wrapper's Haskell type into the type of its raw
-- @foreign import prim@. The passes run in this order:
--
--   1. a 'BS.ByteString' argument becomes two arguments, @Addr# -> Int# -> ...@;
--   2. @Ptr a@ becomes 'Addr#';
--   3. supported base types (integers, words, 'Double', 'Float', 'Unit') are
--      renamed to their primitive counterparts;
--   4. boxed tuples become unboxed tuples.
--
-- Every pass is a plain pure rewrite of the TH syntax tree.
unliftType :: Type -> Type
unliftType = transformBi unliftTuple
           . transformBi unliftBaseTy
           . transformBi unliftPtrs
           . transformBi unliftBS
  where
    unliftBaseTy :: Name -> Name
    unliftBaseTy x
#if MIN_VERSION_GLASGOW_HASKELL(9,4,0,0)
      | x == ''Word64 = ''Word64#
      | x == ''Int64 = ''Int64#
#else
      | x == ''Word64 = ''Word#
      | x == ''Int64 = ''Int#
#endif
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
      | x == ''Word = ''Word#
      | x == ''Word8 = ''Word8#
      | x == ''Word16 = ''Word16#
      | x == ''Word32 = ''Word32#

      | x == ''Int = ''Int#
      | x == ''Int8 = ''Int8#
      | x == ''Int16 = ''Int16#
      | x == ''Int32 = ''Int32#

#else
      | x `elem` [ ''Word, ''Word8, ''Word16, ''Word32, ''Word64 ] = ''Word#
      | x `elem` [ ''Int, ''Int8, ''Int16, ''Int32, ''Int64 ] = ''Int#
#endif
      | x == ''Double = ''Double#
      | x == ''Float = ''Float#
      -- Unit travels as a dummy machine integer.
      | x == ''Unit = ''Int#
      | otherwise = x

    unliftPtrs :: Type -> Type
    unliftPtrs (AppT (ConT name) _) | name == ''Ptr = ConT ''Addr#
    unliftPtrs x = x

    -- A ByteString argument is split into its start address and its length.
    -- The replacement is built directly from TH constructors; it is exactly
    -- what the quote [t| Addr# -> Int# -> rhs |] produces, without needing
    -- to run Q inside unsafePerformIO.
    unliftBS :: Type -> Type
    unliftBS (AppT (AppT ArrowT (ConT bs)) rhs)
      | bs == ''BS.ByteString =
          AppT (AppT ArrowT (ConT ''Addr#)) (AppT (AppT ArrowT (ConT ''Int#)) rhs)
    unliftBS x = x

    unliftTuple :: Type -> Type
    unliftTuple (TupleT n) = UnboxedTupleT n
    unliftTuple x = x

-- | Arity of the tuple returned by a (possibly curried) function type, or
-- 'Nothing' if the result is not a tuple.
--
-- Argument arrows are skipped. After that, the left spine of the result's
-- type application is followed until a 'TupleT' head (or something else)
-- is found.
detectRetTuple :: Type -> Maybe Int
detectRetTuple ty = case ty of
  AppT (AppT ArrowT _) result -> detectRetTuple result
  AppT headTy _ -> detectRetTuple headTy
  TupleT arity -> Just arity
  _ -> Nothing

-- | Argument type of every @ArrowT@ application found anywhere inside the
-- type, in the order uniplate's 'universeBi' enumerates them.
getArgs :: Type -> [Type]
getArgs ty = concatMap argOf (universeBi ty)
  where
    argOf (AppT ArrowT argTy) = [argTy]
    argOf _ = []

-- | Number of 'ArrowT' constructors occurring anywhere inside the type. For a
-- first-order curried function type this is its number of arguments.
countArgs :: Type -> Int
countArgs ty = sum [1 | ArrowT <- universeBi ty]