{-# LANGUAGE CPP #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Struct.TH (makeStruct) where
import Control.Monad (when, zipWithM)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Either (partitionEithers)
import Data.Primitive
import Data.Struct
import Data.Struct.Internal (Dict(Dict), initializeUnboxedField, st)
import Data.List (groupBy, nub)
import Language.Haskell.TH
import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.Syntax (VarStrictType)
#ifdef HLINT
{-# ANN module "HLint: ignore Use ." #-}
#endif
data StructRep = StructRep
{ srState :: Name
, srName :: Name
, srTyVars :: [TyVarBndrUnit]
#if MIN_VERSION_template_haskell(2,12,0)
, srDerived :: [DerivClause]
#else
, srDerived :: Cxt
#endif
, srCxt :: Cxt
, srConstructor :: Name
, srMembers :: [Member]
} deriving Show
data Member = Member
{ _memberRep :: Representation
, memberName :: Name
, _memberType :: Type
}
deriving Show
data Representation = BoxedField | UnboxedField | Slot
deriving Show
-- | Generate allocators, slots, fields, unboxed fields, Eq instances,
-- and Struct instances for the given "data types".
--
-- Inputs are expected to be "data types" parameterized by a state
-- type. Strict fields are considered to be slots, Non-strict fields
-- are considered to be boxed types, Unpacked fields are considered
-- to be unboxed primitives.
--
-- The data type should use record syntax and have a single constructor.
-- The field names will be used to generate slot, field, and unboxedField
-- values of the same name.
--
-- An allocator for the struct is generated by prefixing "alloc" to the
-- data type name.
makeStruct :: DecsQ -> DecsQ
makeStruct dsq =
do ds <- dsq
(passthrough, reps) <- partitionEithers <$> traverse computeRep ds
ds's <- traverse (generateCode passthrough) reps
return (passthrough ++ concat ds's)
mkAllocName :: StructRep -> Name
mkAllocName rep = mkName ("alloc" ++ nameBase (srName rep))
mkInitName :: StructRep -> Name
mkInitName rep = mkName ("new" ++ nameBase (srName rep))
------------------------------------------------------------------------
-- Input validation
------------------------------------------------------------------------
computeRep :: Dec -> Q (Either Dec StructRep)
computeRep (DataD c n vs _ cs ds) =
do state <- validateStateType vs
(conname, confields) <- validateContructor cs
members <- traverse (validateMember state) confields
return $ Right StructRep
{ srState = state
, srName = n
, srTyVars = vs
, srConstructor = conname
, srMembers = members
, srDerived = ds
, srCxt = c
}
computeRep d = return (Left d)
-- | Check that only a single data constructor was provided and
-- that it was a record constructor.
validateContructor :: [Con] -> Q (Name,[VarStrictType])
validateContructor [RecC name fields] = return (name,fields)
validateContructor [_] = fail "Expected a record constructor"
validateContructor xs = fail ("Expected 1 constructor, got " ++ show (length xs))
-- A struct type's final type variable should be suitable for
-- use as the ('PrimState' m) argument.
validateStateType :: [TyVarBndrUnit] -> Q Name
validateStateType xs =
do when (null xs) (fail "state type expected but no type variables found")
elimTV return validateKindedTV (last xs)
where
validateKindedTV :: Name -> Kind -> Q Name
validateKindedTV n k
| k == starK = return n
| otherwise = fail "state type should have kind *"
-- | Figure out which record fields are Slots and which are
-- Fields. Slots will have types ending in the state type
validateMember :: Name -> VarStrictType -> Q Member
validateMember s (fieldname,Bang NoSourceUnpackedness NoSourceStrictness,fieldtype) =
do when (occurs s fieldtype)
(fail ("state type may not occur in field `" ++ nameBase fieldname ++ "`"))
return (Member BoxedField fieldname fieldtype)
validateMember s (fieldname,Bang NoSourceUnpackedness SourceStrict,fieldtype) =
do f <- unapplyType fieldtype s
when (occurs s f)
(fail ("state type may only occur in final position in slot `" ++ nameBase fieldname ++ "`"))
return (Member Slot fieldname f)
validateMember s (fieldname,Bang SourceUnpack SourceStrict,fieldtype) =
do when (occurs s fieldtype)
(fail ("state type may not occur in unpacked field `" ++ nameBase fieldname ++ "`"))
return (Member UnboxedField fieldname fieldtype)
validateMember _ _ = fail "validateMember: can't unpack nonstrict fields"
unapplyType :: Type -> Name -> Q Type
unapplyType (AppT f (VarT x)) y | x == y = return f
unapplyType t n =
fail $ "Unable to match state type of slot: " ++ show t ++ " | expected: " ++ nameBase n
------------------------------------------------------------------------
-- Code generation
------------------------------------------------------------------------
generateCode :: [Dec] -> StructRep -> DecsQ
generateCode ds rep = concat <$> sequence
[ generateDataType rep
, generateStructInstance rep
, generateMembers rep
, generateNew rep
, generateAlloc rep
, generateRoles ds rep
]
-- Generates: newtype TyCon a b c s = DataCon (Object s)
generateDataType :: StructRep -> DecsQ
generateDataType rep = sequence
[ newtypeD (return (srCxt rep)) (srName rep) (srTyVars rep)
Nothing
(normalC
(srConstructor rep)
[ bangType
(bang noSourceUnpackedness noSourceStrictness)
[t| Object $(varT (srState rep)) |]
])
#if MIN_VERSION_template_haskell(2,12,0)
(map return (srDerived rep))
#else
(return (srDerived rep))
#endif
]
generateRoles :: [Dec] -> StructRep -> DecsQ
generateRoles ds rep
| hasRoleAnnotation = return []
| otherwise = sequence [ roleAnnotD (srName rep) (computeRoles rep) ]
where
hasRoleAnnotation = any isTargetRoleAnnot ds
isTargetRoleAnnot (RoleAnnotD n _) = n == srName rep
isTargetRoleAnnot _ = False
computeRoles :: StructRep -> [Role]
computeRoles = map (const NominalR) . srTyVars
repType1 :: StructRep -> TypeQ
repType1 rep = repTypeHelper (srName rep) (init (srTyVars rep))
repType :: StructRep -> TypeQ
repType rep = repTypeHelper (srName rep) (srTyVars rep)
repTypeHelper :: Name -> [TyVarBndrUnit] -> TypeQ
repTypeHelper c vs = foldl appT (conT c) (tyVarBndrT <$> vs)
tyVarBndrT :: TyVarBndrUnit -> TypeQ
tyVarBndrT = elimTV varT (sigT . varT)
generateStructInstance :: StructRep -> DecsQ
generateStructInstance rep =
[d| instance Struct $(repType1 rep) where struct = Dict
instance Eq $(repType rep) where (==) = eqStruct
|]
generateAlloc :: StructRep -> DecsQ
generateAlloc rep =
do mName <- newName "m"
let m :: TypeQ
m = varT mName
n = length (groupBy isNeighbor (srMembers rep))
allocName = mkAllocName rep
simpleDefinition rep allocName
(forallT [plainTVSpecified mName] (cxt [])
[t| PrimMonad $m => $m ( $(repType1 rep) (PrimState $m) ) |])
[| alloc n |]
generateNew :: StructRep -> DecsQ
generateNew rep =
do this <- newName "this"
let ms = groupBy isNeighbor (srMembers rep)
addName m = do n <- newName (nameBase (memberName m))
return (n,m)
msWithArgs <- traverse (traverse addName) ms
let name = mkInitName rep
body = doE
$ bindS (varP this) (varE (mkAllocName rep))
: (noBindS <$> zipWith (assignN (varE this)) [0..] msWithArgs)
++ [ noBindS [| return $(varE this) |] ]
sequence
[ sigD name (newStructType rep)
, funD name [ clause (varP . fst <$> concat msWithArgs)
(normalB [| st $body |] ) [] ]
]
assignN :: ExpQ -> Int -> [(Name,Member)] -> ExpQ
assignN this _ [(arg,Member BoxedField n _)] =
[| setField $(varE n) $this $(varE arg) |]
assignN this _ [(arg,Member Slot n _)] =
[| set $(varE n) $this $(varE arg)|]
assignN this i us =
do let n = length us
mba <- newName "mba"
let arg0 = fst (head us)
doE $ bindS (varP mba) [| initializeUnboxedField i n (sizeOf $(varE arg0)) $this |]
: [ noBindS [| writeByteArray $(varE mba) j $(varE arg) |]
| (j,(arg,_)) <- zip [0 :: Int ..] us ]
newStructType :: StructRep -> TypeQ
newStructType rep =
do mName <- newName "m"
let m :: TypeQ
m = varT mName
s = [t| PrimState $m |]
obj = repType1 rep
buildType (Member BoxedField _ t) = return t
buildType (Member UnboxedField _ t) = return t
buildType (Member Slot _ f) = [t| $(return f) $s |]
r = foldr (-->)
[t| $m ($obj $s) |]
(buildType <$> srMembers rep)
primPreds = primPred <$> nub [ t | Member UnboxedField _ (VarT t) <- srMembers rep ]
forallRepT rep $ forallT [plainTVSpecified mName] (cxt primPreds)
[t| PrimMonad $m => $r |]
generateMembers :: StructRep -> DecsQ
generateMembers rep
= concat <$>
zipWithM
(generateMember1 rep)
[0..]
(groupBy isNeighbor (srMembers rep))
isNeighbor :: Member -> Member -> Bool
isNeighbor (Member UnboxedField _ t) (Member UnboxedField _ u) = t == u
isNeighbor _ _ = False
generateMember1 :: StructRep -> Int -> [Member] -> DecsQ
generateMember1 rep n [Member BoxedField fieldname fieldtype] =
simpleDefinition rep fieldname
[t| Field $(repType1 rep) $(return fieldtype) |]
[| field n |]
generateMember1 rep n [Member Slot slotname slottype] =
simpleDefinition rep slotname
[t| Slot $(repType1 rep) $(return slottype) |]
[| slot n |]
generateMember1 rep n us =
concat <$> sequence
[ simpleDefinition rep fieldname
(addPrimCxt fieldtype
[t| Field $(repType1 rep) $(return fieldtype) |])
[| unboxedField n i |]
| (i,Member UnboxedField fieldname fieldtype) <- zip [0 :: Int ..] us
]
where
addPrimCxt (VarT t) = forallT [] (cxt [primPred t])
addPrimCxt _ = id
simpleDefinition :: StructRep -> Name -> TypeQ -> ExpQ -> DecsQ
simpleDefinition rep name typ def =
sequence
[ sigD name (forallRepT rep typ)
, simpleValD name def
, pragInlD name Inline FunLike AllPhases
]
simpleValD :: Name -> ExpQ -> DecQ
simpleValD var val = valD (varP var) (normalB val) []
forallRepT :: StructRep -> TypeQ -> TypeQ
forallRepT rep = forallT (init (changeTVFlags SpecifiedSpec (srTyVars rep))) (cxt [])
(-->) :: TypeQ -> TypeQ -> TypeQ
f --> x = arrowT `appT` f `appT` x
primPred :: Name -> PredQ
primPred t = [t| Prim $(varT t) |]
occurs :: Name -> Type -> Bool
occurs n (AppT f x) = occurs n f || occurs n x
occurs n (VarT m) = n == m
occurs n (ForallT _ _ t) = occurs n t
occurs n (SigT t _) = occurs n t
occurs _ _ = False