{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.C.Inline.Context
(
TypesTable
, Purity(..)
, convertType
, CArray
, typeNamesFromTypesTable
, AntiQuoter(..)
, AntiQuoterId
, SomeAntiQuoter(..)
, AntiQuoters
, Context(..)
, baseCtx
, fptrCtx
, funCtx
, vecCtx
, VecCtx(..)
, bsCtx
) where
import Control.Applicative ((<|>))
import Control.Exception (bracket)
import Control.Monad (mzero, forM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import Data.Coerce
import Data.Int (Int8, Int16, Int32, Int64)
import qualified Data.Map as Map
import Data.Typeable (Typeable)
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as VM
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, FunPtr, freeHaskellFunPtr)
import Foreign.Storable (Storable)
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH
import qualified Text.Parser.Token as Parser
import qualified Data.HashSet as HashSet
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup, (<>))
#else
import Data.Monoid ((<>))
#endif
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Data.Traversable (traverse)
#endif
import Language.C.Inline.FunPtr
import qualified Language.C.Types as C
import Language.C.Inline.HaskellIdentifier
-- | Mapping from C type specifiers to the Haskell types they convert to.
-- 'convertType' consults it for every specifier (and template name) it meets.
type TypesTable = Map.Map C.TypeSpecifier TH.TypeQ

-- | Whether generated Haskell code is pure or runs in 'IO'. In
-- 'convertType', the result type of a converted function pointer is wrapped
-- in 'IO' unless this is 'Pure'.
data Purity
  = Pure
  | IO
  deriving (Eq, Show)
-- | Parser and marshaller for one kind of anti-quotation. The payload of
-- type @a@ is produced by 'aqParser' and consumed by 'aqMarshaller'.
data AntiQuoter a = AntiQuoter
  { aqParser :: forall m. C.CParser HaskellIdentifier m => m (C.CIdentifier, C.Type C.CIdentifier, a)
    -- ^ Parses the anti-quotation body. Returns the identifier to use on
    -- the C side, its C type, and the payload handed to 'aqMarshaller'.
  , aqMarshaller :: Purity -> TypesTable -> C.Type C.CIdentifier -> a -> TH.Q (TH.Type, TH.Exp)
    -- ^ Given the purity, the types table, the C type and the payload,
    -- produces the Haskell type of the captured value and an expression.
    -- Every marshaller defined in this module returns an expression that
    -- takes a continuation over that value (a @\\cont -> ...@ lambda or a
    -- partially applied @withForeignPtr@ / @vecCtxUnsafeWith@).
  }

-- | Name under which an anti-quoter is registered, e.g. @fptr-ptr@ or
-- @vec-len@ for the anti-quoters defined in this module.
type AntiQuoterId = String

-- | An 'AntiQuoter' with its payload type hidden. This lets anti-quoters
-- with different payload types live in one 'AntiQuoters' map. The payload's
-- 'Eq' and 'Typeable' dictionaries are packed alongside it.
data SomeAntiQuoter = forall a. (Eq a, Typeable a) => SomeAntiQuoter (AntiQuoter a)

-- | All anti-quoters available in a 'Context', keyed by name.
type AntiQuoters = Map.Map AntiQuoterId SomeAntiQuoter
-- | Bundle of type mappings, anti-quoters and output settings. Contexts are
-- combined with '<>'; on conflicts the right-hand context's entries win.
data Context = Context
  { ctxTypesTable :: TypesTable
    -- ^ C-to-Haskell type mapping consulted by 'convertType'.
  , ctxAntiQuoters :: AntiQuoters
    -- ^ Available anti-quoters, keyed by name.
  , ctxOutput :: Maybe (String -> String)
    -- ^ Optional @String -> String@ transformation. It is not used in this
    -- module; presumably it post-processes the generated C source (TODO
    -- confirm).
  , ctxForeignSrcLang :: Maybe TH.ForeignSrcLang
    -- ^ Optional Template Haskell foreign-source language. It is not used in
    -- this module; presumably it selects how generated code is attached as a
    -- foreign file (TODO confirm).
  , ctxEnableCpp :: Bool
    -- ^ C++ mode flag; 'False' in 'mempty' and OR-ed when combining.
    -- NOTE(review): the parsers here read a flag via 'C.parseEnableCpp';
    -- presumably it is fed from this field — confirm.
  }
-- | Combine two contexts. The right-hand context takes precedence:
--
-- * its types-table and anti-quoter entries override those of the left-hand
--   context (both maps use the left-biased union with the right operand
--   first);
-- * its 'ctxOutput' and 'ctxForeignSrcLang' are used whenever they are
--   'Just';
-- * 'ctxEnableCpp' is on if either side enables it.
--
-- Both the 'Semigroup' instance and the pre-base-4.11 'mappend' use this
-- function, so the two definitions cannot drift apart.
appendContext :: Context -> Context -> Context
appendContext lhs rhs = Context
  { ctxTypesTable = ctxTypesTable rhs <> ctxTypesTable lhs
  , ctxAntiQuoters = ctxAntiQuoters rhs <> ctxAntiQuoters lhs
  , ctxOutput = ctxOutput rhs <|> ctxOutput lhs
  , ctxForeignSrcLang = ctxForeignSrcLang rhs <|> ctxForeignSrcLang lhs
  , ctxEnableCpp = ctxEnableCpp rhs || ctxEnableCpp lhs
  }

#if MIN_VERSION_base(4,9,0)
instance Semigroup Context where
  (<>) = appendContext
#endif

instance Monoid Context where
  -- Identity for 'appendContext': empty maps, no overrides, C++ off.
  mempty = Context
    { ctxTypesTable = mempty
    , ctxAntiQuoters = mempty
    , ctxOutput = Nothing
    , ctxForeignSrcLang = Nothing
    , ctxEnableCpp = False
    }
#if !MIN_VERSION_base(4,11,0)
  mappend = appendContext
#endif
-- | The empty context extended with 'baseTypesTable', i.e. the standard C
-- scalar types and common C typedefs; it has no anti-quoters.
baseCtx :: Context
baseCtx = mempty { ctxTypesTable = baseTypesTable }
-- | Default C-to-Haskell type mapping. It covers the type specifiers built
-- into the C grammar plus a set of typedef names from the C standard library
-- and POSIX headers (e.g. @size_t@, @int32_t@, @FILE@).
baseTypesTable :: Map.Map C.TypeSpecifier TH.TypeQ
baseTypesTable =
  Map.fromList (builtins ++ map (\(name, hsTy) -> (C.TypeName name, hsTy)) typedefs)
  where
    -- Specifiers that are C keywords.
    builtins =
      [ (C.Void, [t| () |])
      , (C.Bool, [t| CBool |])
      , (C.Char Nothing, [t| CChar |])
      , (C.Char (Just C.Signed), [t| CSChar |])
      , (C.Char (Just C.Unsigned), [t| CUChar |])
      , (C.Short C.Signed, [t| CShort |])
      , (C.Short C.Unsigned, [t| CUShort |])
      , (C.Int C.Signed, [t| CInt |])
      , (C.Int C.Unsigned, [t| CUInt |])
      , (C.Long C.Signed, [t| CLong |])
      , (C.Long C.Unsigned, [t| CULong |])
      , (C.LLong C.Signed, [t| CLLong |])
      , (C.LLong C.Unsigned, [t| CULLong |])
      , (C.Float, [t| CFloat |])
      , (C.Double, [t| CDouble |])
      ]
    -- Typedef names, registered under 'C.TypeName'.
    typedefs =
      [ ("ptrdiff_t", [t| CPtrdiff |])
      , ("size_t", [t| CSize |])
      , ("wchar_t", [t| CWchar |])
      , ("sig_atomic_t", [t| CSigAtomic |])
      , ("intptr_t", [t| CIntPtr |])
      , ("uintptr_t", [t| CUIntPtr |])
      , ("intmax_t", [t| CIntMax |])
      , ("uintmax_t", [t| CUIntMax |])
      , ("clock_t", [t| CClock |])
      , ("time_t", [t| CTime |])
      , ("useconds_t", [t| CUSeconds |])
      , ("suseconds_t", [t| CSUSeconds |])
      , ("FILE", [t| CFile |])
      , ("fpos_t", [t| CFpos |])
      , ("jmp_buf", [t| CJmpBuf |])
      , ("int8_t", [t| Int8 |])
      , ("int16_t", [t| Int16 |])
      , ("int32_t", [t| Int32 |])
      , ("int64_t", [t| Int64 |])
      , ("uint8_t", [t| Word8 |])
      , ("uint16_t", [t| Word16 |])
      , ("uint32_t", [t| Word32 |])
      , ("uint64_t", [t| Word64 |])
      ]
-- | Haskell representation of C array types. It is simply an alias for 'Ptr':
-- 'convertType' turns a C array of @T@ into @CArray T@, i.e. @Ptr T@.
type CArray = Ptr
-- | Convert a C type to the corresponding Haskell type, using the given
-- 'TypesTable' for type specifiers.
--
-- Returns 'Nothing' when some part of the type has no entry in the table,
-- or is otherwise unsupported (e.g. a bare function prototype that is not
-- behind a pointer).
--
-- * A template @T\<A\>@ becomes @T A@. A template with several arguments
--   becomes @T '(A, B, ...)@, a promoted tuple; any number of arguments is
--   accepted, up to GHC's tuple-size limit.
-- * A template constant becomes a type-level natural literal. A constant
--   that is not a non-negative integer fails in 'TH.Q' with a message.
-- * A pointer to a prototype becomes a 'FunPtr' to a curried Haskell
--   function. Its result is wrapped in 'IO' unless @purity@ is 'Pure'.
-- * Any other pointer becomes 'Ptr', and an array becomes 'CArray'.
convertType
  :: Purity
  -> TypesTable
  -> C.Type C.CIdentifier
  -> TH.Q (Maybe TH.Type)
convertType purity cTypes = runMaybeT . go
  where
    goDecl = go . C.parameterDeclarationType

    go :: C.Type C.CIdentifier -> MaybeT TH.Q TH.Type
    go cTy = case cTy of
      C.TypeSpecifier _specs (C.Template ident' cTys) -> do
        symbol <- case Map.lookup (C.TypeName ident') cTypes of
          Nothing -> mzero
          Just ty -> return ty
        -- Template arguments carry no specifiers of their own; use the
        -- empty set instead of a bottom value.
        hsTys <- forM cTys $ \cTy' -> go (C.TypeSpecifier mempty cTy')
        case hsTys of
          -- NB: in 'MaybeT', 'fail' discards its message and yields Nothing.
          [] -> fail "Can not find template parameters."
          [a] -> lift [t| $(symbol) $(return a) |]
          _ -> lift [t| $(symbol) $(return (promotedTuple hsTys)) |]
      C.TypeSpecifier _specs (C.TemplateConst num) ->
        -- Total replacement for 'read': a single parse with only
        -- whitespace left over.
        case [n | (n, rest) <- reads num, ("", "") <- lex rest] of
          [n] | n >= 0 -> return (TH.LitT (TH.NumTyLit n))
          _ -> lift (fail ("Invalid template constant: " ++ show num))
      C.TypeSpecifier _specs cSpec ->
        case Map.lookup cSpec cTypes of
          Nothing -> mzero
          Just ty -> lift ty
      C.Ptr _quals (C.Proto retType pars) -> do
        hsRetType <- go retType
        hsPars <- mapM goDecl pars
        lift [t| FunPtr $(buildArr hsPars hsRetType) |]
      C.Ptr _quals cTy' -> do
        hsTy <- go cTy'
        lift [t| Ptr $(return hsTy) |]
      C.Array _mbSize cTy' -> do
        hsTy <- go cTy'
        lift [t| CArray $(return hsTy) |]
      C.Proto _retType _pars ->
        mzero

    -- Build @'(a, b, ...)@ as the promoted tuple constructor applied to
    -- every element. This is the same AST a @'(..)@ type quote produces.
    promotedTuple hsTys =
      foldl TH.AppT (TH.PromotedTupleT (length hsTys)) hsTys

    -- Curried function type from the parameter types to the result type;
    -- the result is wrapped in 'IO' unless generating pure code.
    buildArr [] hsRetType =
      case purity of
        Pure -> return hsRetType
        IO -> [t| IO $(return hsRetType) |]
    buildArr (hsPar : hsPars) hsRetType =
      [t| $(return hsPar) -> $(buildArr hsPars hsRetType) |]
-- | Collect every typedef name ('C.TypeName' key) registered in the table.
-- The parser needs this set to recognise those names as types.
typeNamesFromTypesTable :: TypesTable -> C.TypeNames
typeNamesFromTypesTable = Map.foldrWithKey collect HashSet.empty
  where
    collect (C.TypeName name) _ acc = HashSet.insert name acc
    collect _ _ acc = acc
-- | Resolve a captured Haskell identifier to a variable expression. Fails
-- in 'TH.Q' when no value of that name is in scope; @err@ names the
-- context that requested the capture and is appended to the message.
getHsVariable :: String -> HaskellIdentifier -> TH.ExpQ
getHsVariable err s =
  TH.lookupValueName name >>= maybe notInScope TH.varE
  where
    name = unHaskellIdentifier s
    notInScope = fail $
      "Cannot capture Haskell variable " ++ name ++
      ", because it's not in scope. (" ++ err ++ ")"
-- | Like 'convertType', but fails in 'TH.Q' instead of returning 'Nothing'.
-- @err@ names the calling context and is included in the message.
convertType_ :: String -> Purity -> TypesTable -> C.Type C.CIdentifier -> TH.Q TH.Type
convertType_ err purity cTypes cTy =
  convertType purity cTypes cTy >>= maybe cannotConvert return
  where
    cannotConvert = fail $ "Cannot convert C type (" ++ err ++ ")"
-- | Context providing the @fptr-ptr@ anti-quoter, which passes the pointer
-- inside a captured foreign pointer (see 'fptrAntiQuoter').
fptrCtx :: Context
fptrCtx = mempty
  { ctxAntiQuoters = Map.singleton "fptr-ptr" (SomeAntiQuoter fptrAntiQuoter)
  }
-- | @fptr-ptr@: the captured variable is 'coerce'd and passed to
-- 'withForeignPtr'. The resulting expression therefore takes a
-- continuation over the raw pointer.
fptrAntiQuoter :: AntiQuoter HaskellIdentifier
fptrAntiQuoter = AntiQuoter
  { aqParser = cDeclAqParser
  , aqMarshaller = marshal
  }
  where
    marshal purity cTypes cTy cId = do
      hsTy <- convertType_ "fptrCtx" purity cTypes cTy
      fptr <- getHsVariable "fptrCtx" cId
      withPtr <- [| withForeignPtr (coerce $(return fptr)) |]
      return (hsTy, withPtr)
-- | Context providing two function-pointer anti-quoters:
--
-- * @fun@: the 'FunPtr' is freed once the C code returns;
-- * @fun-alloc@: the 'FunPtr' is never freed by this module.
funCtx :: Context
funCtx = mempty
  { ctxAntiQuoters = Map.fromList $ map (fmap SomeAntiQuoter)
      [ ("fun", funPtrAntiQuoter)
      , ("fun-alloc", funAllocPtrAntiQuoter)
      ]
  }
-- | @fun@: turns a captured Haskell function into a 'FunPtr' for the
-- duration of one call.
--
-- The generated expression takes a continuation. It allocates the 'FunPtr'
-- (via 'mkFunPtr') and runs the continuation with it. It then releases the
-- 'FunPtr' with 'freeHaskellFunPtr'. The release is done through 'bracket',
-- so the pointer is freed even if the continuation throws, including on
-- asynchronous exceptions; previously it leaked in that case.
--
-- Fails at compile time unless the declared C type converts to a 'FunPtr'.
funPtrAntiQuoter :: AntiQuoter HaskellIdentifier
funPtrAntiQuoter = AntiQuoter
  { aqParser = cDeclAqParser
  , aqMarshaller = \purity cTypes cTy cId -> do
      hsTy <- convertType_ "funCtx" purity cTypes cTy
      hsExp <- getHsVariable "funCtx" cId
      case hsTy of
        TH.AppT (TH.ConT n) hsTy' | n == ''FunPtr -> do
          hsExp' <- [| \cont ->
                         bracket ($(mkFunPtr (return hsTy')) $(return hsExp))
                                 freeHaskellFunPtr
                                 cont
                     |]
          return (hsTy, hsExp')
        _ -> fail "The `fun' marshaller captures function pointers only"
  }
-- | @fun-alloc@: like @fun@, but the 'FunPtr' created by 'mkFunPtr' is
-- never freed here. It is simply passed to the continuation, so releasing it
-- is left to whoever ends up owning it.
funAllocPtrAntiQuoter :: AntiQuoter HaskellIdentifier
funAllocPtrAntiQuoter = AntiQuoter
  { aqParser = cDeclAqParser
  , aqMarshaller = \purity cTypes cTy cId -> do
      hsTy <- convertType_ "funCtx" purity cTypes cTy
      hsFun <- getHsVariable "funCtx" cId
      case hsTy of
        TH.AppT (TH.ConT con) funTy
          | con == ''FunPtr -> do
              withFunPtr <- [| \cont -> $(mkFunPtr (return funTy)) $(return hsFun) >>= cont |]
              return (hsTy, withFunPtr)
        _ -> fail "The `fun-alloc' marshaller captures function pointers only"
  }
-- | Context providing the vector anti-quoters:
--
-- * @vec-ptr@: pointer to the elements of a 'VecCtx' container;
-- * @vec-len@: its element count as a C @long@.
vecCtx :: Context
vecCtx = mempty
  { ctxAntiQuoters = Map.fromList $ map (fmap SomeAntiQuoter)
      [ ("vec-ptr", vecPtrAntiQuoter)
      , ("vec-len", vecLenAntiQuoter)
      ]
  }
-- | Containers that can be handed to C as a pointer to their elements plus
-- an element count. The @vec-ptr@ and @vec-len@ anti-quoters use this class.
class VecCtx a where
  -- | Element type stored in the container.
  type VecCtxScalar a :: *

  -- | Number of elements in the container.
  vecCtxLength :: a -> Int

  -- | Run an action with a pointer to the container's elements. The
  -- pointer is presumably valid only for the duration of the action, since
  -- both instances use the vector library's @unsafeWith@.
  vecCtxUnsafeWith :: a -> (Ptr (VecCtxScalar a) -> IO b) -> IO b

-- | Immutable storable vectors.
instance Storable e => VecCtx (V.Vector e) where
  type VecCtxScalar (V.Vector e) = e
  vecCtxLength vec = V.length vec
  vecCtxUnsafeWith vec action = V.unsafeWith vec action

-- | Mutable storable vectors in 'IO'.
instance Storable e => VecCtx (VM.IOVector e) where
  type VecCtxScalar (VM.IOVector e) = e
  vecCtxLength mvec = VM.length mvec
  vecCtxUnsafeWith mvec action = VM.unsafeWith mvec action
-- | @vec-ptr@: passes a pointer to the elements of the captured 'VecCtx'
-- container. The expression is 'vecCtxUnsafeWith' partially applied to the
-- variable, so it takes a continuation over the pointer.
vecPtrAntiQuoter :: AntiQuoter HaskellIdentifier
vecPtrAntiQuoter = AntiQuoter
  { aqParser = cDeclAqParser
  , aqMarshaller = marshal
  }
  where
    marshal purity cTypes cTy cId = do
      hsTy <- convertType_ "vecCtx" purity cTypes cTy
      vec <- getHsVariable "vecCtx" cId
      withPtr <- [| vecCtxUnsafeWith $(return vec) |]
      return (hsTy, withPtr)
-- | @vec-len@: passes the element count of the captured 'VecCtx' container
-- as a 'CLong'. The parser always declares the C type as @long@; the
-- marshaller rejects anything else as an internal error.
vecLenAntiQuoter :: AntiQuoter HaskellIdentifier
vecLenAntiQuoter = AntiQuoter
  { aqParser =
      C.parseIdentifier >>= \hId ->
      C.parseEnableCpp >>= \useCpp ->
      return ( mangleHaskellIdentifier useCpp hId
             , C.TypeSpecifier mempty (C.Long C.Signed)
             , hId
             )
  , aqMarshaller = \_purity _cTypes cTy cId ->
      case cTy of
        C.TypeSpecifier _ (C.Long C.Signed) -> do
          vec <- getHsVariable "vecCtx" cId
          hsTy <- [t| CLong |]
          withLen <- [| \cont -> cont (fromIntegral (vecCtxLength $(return vec))) |]
          return (hsTy, withLen)
        _ ->
          fail "impossible: got type different from `long' (vecCtx)"
  }
-- | Context providing the 'BS.ByteString' anti-quoters:
--
-- * @bs-ptr@: pointer to the bytes, via 'BS.unsafeUseAsCString';
-- * @bs-len@: byte count as a C @long@;
-- * @bs-cstr@: pointer obtained via 'BS.useAsCString'.
bsCtx :: Context
bsCtx = mempty
  { ctxAntiQuoters = Map.fromList $ map (fmap SomeAntiQuoter)
      [ ("bs-ptr", bsPtrAntiQuoter)
      , ("bs-len", bsLenAntiQuoter)
      , ("bs-cstr", bsCStrAntiQuoter)
      ]
  }
-- | @bs-ptr@: passes the captured 'BS.ByteString' as a @char *@ obtained
-- with 'BS.unsafeUseAsCString'. Per the bytestring documentation, this
-- pointer refers to the existing buffer and is not guaranteed to be
-- NUL-terminated.
bsPtrAntiQuoter :: AntiQuoter HaskellIdentifier
bsPtrAntiQuoter = AntiQuoter
  { aqParser =
      C.parseIdentifier >>= \hId ->
      C.parseEnableCpp >>= \useCpp ->
      return ( mangleHaskellIdentifier useCpp hId
             , C.Ptr [] (C.TypeSpecifier mempty (C.Char Nothing))
             , hId
             )
  , aqMarshaller = \_purity _cTypes cTy cId ->
      case cTy of
        C.Ptr _ (C.TypeSpecifier _ (C.Char Nothing)) -> do
          hsTy <- [t| Ptr CChar |]
          bs <- getHsVariable "bsCtx" cId
          withCStr <- [| \cont -> BS.unsafeUseAsCString $(return bs) cont |]
          return (hsTy, withCStr)
        _ ->
          fail "impossible: got type different from `char *' (bsCtx)"
  }
-- | @bs-len@: passes the length of the captured 'BS.ByteString' as a
-- 'CLong'. The parser always declares the C type as @long@; the marshaller
-- rejects anything else as an internal error.
bsLenAntiQuoter :: AntiQuoter HaskellIdentifier
bsLenAntiQuoter = AntiQuoter
  { aqParser =
      C.parseIdentifier >>= \hId ->
      C.parseEnableCpp >>= \useCpp ->
      return ( mangleHaskellIdentifier useCpp hId
             , C.TypeSpecifier mempty (C.Long C.Signed)
             , hId
             )
  , aqMarshaller = \_purity _cTypes cTy cId ->
      case cTy of
        C.TypeSpecifier _ (C.Long C.Signed) -> do
          bs <- getHsVariable "bsCtx" cId
          hsTy <- [t| CLong |]
          withLen <- [| \cont -> cont (fromIntegral (BS.length $(return bs))) |]
          return (hsTy, withLen)
        _ ->
          fail "impossible: got type different from `long' (bsCtx)"
  }
-- | @bs-cstr@: passes the captured 'BS.ByteString' as a @char *@ obtained
-- with 'BS.useAsCString'. Per the bytestring documentation, that function
-- copies the bytes into a NUL-terminated buffer.
bsCStrAntiQuoter :: AntiQuoter HaskellIdentifier
bsCStrAntiQuoter = AntiQuoter
  { aqParser =
      C.parseIdentifier >>= \hId ->
      C.parseEnableCpp >>= \useCpp ->
      return ( mangleHaskellIdentifier useCpp hId
             , C.Ptr [] (C.TypeSpecifier mempty (C.Char Nothing))
             , hId
             )
  , aqMarshaller = \_purity _cTypes cTy cId ->
      case cTy of
        C.Ptr _ (C.TypeSpecifier _ (C.Char Nothing)) -> do
          hsTy <- [t| Ptr CChar |]
          bs <- getHsVariable "bsCtx" cId
          withCStr <- [| \cont -> BS.useAsCString $(return bs) cont |]
          return (hsTy, withCStr)
        _ ->
          fail "impossible: got type different from `char *' (bsCtx)"
  }
-- | Parse a parenthesised, named C parameter declaration.
--
-- Returns:
--
-- * the mangled C identifier for the declared name;
-- * the declared type, with Haskell identifiers converted to C ones;
-- * the original Haskell identifier.
--
-- This parser is shared by the @fptr-ptr@, @fun@, @fun-alloc@ and
-- @vec-ptr@ anti-quoters. The error for an unnamed declaration therefore
-- names all three contexts; it used to blame only funCtx and "functions".
cDeclAqParser
  :: C.CParser HaskellIdentifier m
  => m (C.CIdentifier, C.Type C.CIdentifier, HaskellIdentifier)
cDeclAqParser = do
  decl <- Parser.parens C.parseParameterDeclaration
  useCpp <- C.parseEnableCpp
  case C.parameterDeclarationId decl of
    Nothing ->
      fail "Every captured variable must be named (fptrCtx/funCtx/vecCtx)"
    Just hId -> do
      let cId = mangleHaskellIdentifier useCpp hId
      cTy <- deHaskellifyCType (C.parameterDeclarationType decl)
      return (cId, cTy, hId)
-- | Convert every Haskell identifier inside a C type into a C identifier,
-- using the parser's C++ setting. Fails in the parser with an explanatory
-- message for identifiers that are not valid C identifiers.
deHaskellifyCType
  :: C.CParser HaskellIdentifier m
  => C.Type HaskellIdentifier -> m (C.Type C.CIdentifier)
deHaskellifyCType = traverse $ \hId -> do
  useCpp <- C.parseEnableCpp
  let name = unHaskellIdentifier hId
  either
    (\err -> fail $ "Illegal Haskell identifier " ++ name ++ " in C type:\n" ++ err)
    return
    (C.cIdentifierFromString useCpp name)