{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Store.TH.Internal
(
deriveManyStoreFromStorable
, deriveTupleStoreInstance
, deriveGenericInstance
, deriveManyStorePrimVector
, deriveManyStoreUnboxVector
, deriveStore
, makeStore
, getAllInstanceTypes1
, isMonoType
) where
import Control.Applicative
import Data.Complex ()
import Data.Generics.Aliases (extT)
import Data.Generics.Schemes (listify, everywhere)
import Data.List (find, nub)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Primitive.ByteArray
import Data.Primitive.Types
import Data.Store.Core
import Data.Store.Impl
import qualified Data.Text as T
import Data.Traversable (forM)
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Unboxed as UV
import Data.Word
import Foreign.Storable (Storable)
import GHC.Types (Int(..))
import Language.Haskell.TH
import Language.Haskell.TH.ReifyMany.Internal (TypeclassInstance(..), getInstances, unAppsT)
import Language.Haskell.TH.Syntax (lift)
import Prelude
import Safe (headMay)
import TH.Derive (Deriver(..))
import TH.ReifySimple
import TH.Utilities (expectTyCon1, dequalify, plainInstanceD, appsT)
-- | Lets a @Store@ instance be requested through the 'Deriver' interface.
--
-- The requested instance head is passed to 'expectTyCon1' together with
-- @''Store@ to obtain the type to derive for; that type is reified with
-- 'reifyDataTypeSubstituted' and its constructors are handed to
-- 'deriveStore' along with the supplied context.
instance Deriver (Store a) where
    runDeriver _proxy ctx instHead = do
        storedTy <- expectTyCon1 ''Store instHead
        info <- reifyDataTypeSubstituted storedTy
        dec <- deriveStore ctx storedTy (dtCons info)
        return [dec]
-- | Generate a @Store@ instance for the data type called @name@.
--
-- The type is reified, every one of its type variables is given a @Store@
-- constraint (via 'storePred'), and the instance is declared for the type
-- constructor applied to all of those variables.  The result is a
-- single-element list holding the declaration built by 'deriveStore'.
makeStore :: Name -> Q [Dec]
makeStore name = do
    info <- reifyDataType name
    let tyVars = map VarT (dtTvs info)
        ctx = map storePred tyVars
        instTy = appsT (ConT name) tyVars
    dec <- deriveStore ctx instTy (dtCons info)
    return [dec]
-- | Build the @Store@ instance declaration for a data type.
--
-- @preds@ becomes the instance context, @headTy@ is the type the instance is
-- declared for, and @cons0@ lists the constructors of that type.
--
-- The generated methods use the following layout:
--
-- * Types with two or more constructors are prefixed with a tag holding the
--   zero-based index of the constructor.  The tag type is the first of
--   @Word8@, @Word16@, @Word32@, @Word64@ whose 'maxBound' is at least the
--   number of constructors.  Types with zero or one constructor carry no tag.
--
-- * The fields of the constructor follow, in declaration order.
--
-- The generated @size@ is a @ConstSize@ when every field has a constant size
-- and all constructors add up to the same total; otherwise it is a
-- @VarSize@.
--
-- Calls @error "Too many constructors"@ when no tag type is large enough.
deriveStore :: Cxt -> Type -> [DataCon] -> Q Dec
deriveStore preds headTy cons0 =
    makeStoreInstance preds headTy
        <$> sizeExpr
        <*> peekExpr
        <*> pokeExpr
  where
    -- Each constructor paired with its fields.  Every field is given a
    -- unique name of the form @c<constructor index>f<field index>@.
    cons :: [(Name, [(Name, Type)])]
    cons =
      [ ( dcName dc
        , [ (mkName ("c" ++ show ixc ++ "f" ++ show ixf), ty)
          | ixf <- ints
          | (_, ty) <- dcFields dc
          ]
        )
      | ixc <- ints
      | dc <- cons0
      ]
    -- Every field of every constructor, in constructor order.  This is
    -- exactly the component list of the tuple built by 'sizeExpr'.
    allFields :: [(Name, Type)]
    allFields = concatMap snd cons
    -- Largest tuple 'sizeExpr' is allowed to build.
    -- NOTE(review): 62 is the bound the original check used; presumably
    -- GHC's maximum tuple size -- confirm for the compilers supported.
    maxTupleArity :: Int
    maxTupleArity = 62
    -- Tag type and its size in bytes, picked from 'tagTypes'.
    (tagType, _, tagSize) =
        fromMaybe (error "Too many constructors") $
        find (\(_, maxN, _) -> maxN >= length cons) tagTypes
    -- (tag type, largest constructor count it is used for, size in bytes).
    tagTypes :: [(Name, Int, Int)]
    tagTypes =
        [ ('(), 1, 0)
        , (''Word8, fromIntegral (maxBound :: Word8), 1)
        , (''Word16, fromIntegral (maxBound :: Word16), 2)
        , (''Word32, fromIntegral (maxBound :: Word32), 4)
        , (''Word64, fromIntegral (maxBound :: Word64), 8)
        ]
    fName ix = mkName ("f" ++ show ix)
    ints = [0..] :: [Int]
    fNames = map fName ints
    -- One name (@sz0@, @sz1@, ...) per constructor, for its total size.
    sizeNames = zipWith (\_ -> mkName . ("sz" ++) . show) cons ints
    tagName = mkName "tag"
    valName = mkName "val"
    sizeExpr
        -- The scrutinee below is a tuple with one component per *field*, so
        -- it is the total field count (not the constructor count) that has
        -- to fit into a tuple.
        | length allFields <= maxTupleArity =
            caseE (tupE (map sizeAtType allFields))
                  (case cons of
                     -- With no fields at all the scrutinee is @()@ and a
                     -- second alternative would merely overlap the first.
                     [] -> [matchConstSize]
                     [c] | null (snd c) -> [matchConstSize]
                     _ -> [matchConstSize, matchVarSize])
        -- Too many fields for a tuple: skip the constant-size analysis.
        -- No tuple pattern binds the per-field size variables here, so
        -- each field's 'size' is referenced directly instead.
        | otherwise = varSizeExpr sizeAtType
      where
        -- @size@ annotated with the type of the given field.
        sizeAtType :: (Name, Type) -> ExpQ
        sizeAtType (_, ty) = [| size :: Size $(return ty) |]
        -- Alternative taken when every field is @ConstSize@ and all
        -- constructors have the same total (checked by the guard).
        matchConstSize :: MatchQ
        matchConstSize = do
            let sz0 = VarE (mkName "sz0")
                sizeDecls =
                    if null sizeNames
                        then [valD (varP (mkName "sz0")) (normalB [| 0 |]) []]
                        else zipWith constSizeDec sizeNames cons
            sameSizeExpr <-
                case sizeNames of
                    (_ : tailSizeNames) ->
                        foldl (\l r -> [| $(l) && $(r) |]) [| True |] $
                        map (\szn -> [| $(return sz0) == $(varE szn) |]) tailSizeNames
                    [] -> [| True |]
            result <- [| ConstSize (tagSize + $(return sz0)) |]
            match (tupP (map (\(n, _) -> conP 'ConstSize [varP n]) allFields))
                  (guardedB [return (NormalG sameSizeExpr, result)])
                  sizeDecls
        -- Declares @szN@ as the sum of the field sizes of constructor N.
        constSizeDec :: Name -> (Name, [(Name, Type)]) -> DecQ
        constSizeDec szn (_, []) =
            valD (varP szn) (normalB [| 0 |]) []
        constSizeDec szn (_, fields) =
            valD (varP szn) body []
          where
            body = normalB $
                foldl1 (\l r -> [| $(l) + $(r) |]) $
                map (\(sizeName, _) -> varE sizeName) fields
        -- Alternative taken otherwise: binds every field's 'Size' to its
        -- @cNfM@ name and defers to 'varSizeExpr' using those variables.
        matchVarSize :: MatchQ
        matchVarSize = do
            match (tupP (map (\(n, _) -> varP n) allFields))
                  (normalB (varSizeExpr (varE . fst)))
                  []
        -- @VarSize@ that adds the tag size to the sizes of the fields of
        -- whichever constructor the value was built with.  The argument
        -- says how to obtain the 'Size' of a field.
        varSizeExpr :: ((Name, Type) -> ExpQ) -> ExpQ
        varSizeExpr fieldSize =
            [| VarSize $ \x -> tagSize + $(caseE [| x |] (map (matchVar fieldSize) cons)) |]
        matchVar :: ((Name, Type) -> ExpQ) -> (Name, [(Name, Type)]) -> MatchQ
        matchVar _ (cname, []) =
            match (conP cname []) (normalB [| 0 |]) []
        matchVar fieldSize (cname, fields) =
            match (conP cname (zipWith (\_ fn -> varP fn) fields fNames))
                  body
                  []
          where
            body = normalB $
                foldl1 (\l r -> [| $(l) + $(r) |])
                (zipWith (\field fn -> [| getSizeWith $(fieldSize field) $(varE fn) |])
                         fields
                         fNames)
    -- Reads the tag (only when there are several constructors) and then the
    -- fields of the selected constructor.
    peekExpr = case cons of
        [] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (show headTy)) ++ ")") |]
        [con] -> peekCon con
        _ -> doE
            [ bindS (varP tagName) [| peek |]
            , noBindS (caseE (sigE (varE tagName) (conT tagType))
                             (map peekMatch (zip [0..] cons) ++ [peekErr]))
            ]
    peekMatch (ix, con) = match (litP (IntegerL ix)) (normalB (peekCon con)) []
    peekErr = match wildP (normalB
        [| peekException $ T.pack $ "Found invalid tag while peeking (" ++ $(lift (show headTy)) ++ ")" |]) []
    peekCon (cname, fields) =
        case fields of
            [] -> [| pure $(conE cname) |]
            _ -> doE $
                map (\(fn, _) -> bindS (varP fn) [| peek |]) fields ++
                [noBindS $ appE (varE 'return) $ appsE $ conE cname : map (\(fn, _) -> varE fn) fields]
    -- Writes the tag (only when there are several constructors) and then
    -- every field of the matched constructor.
    pokeExpr = lamE [varP valName] $ caseE (varE valName) $ zipWith pokeCon [0..] cons
    pokeCon :: Int -> (Name, [(Name, Type)]) -> MatchQ
    pokeCon ix (cname, fields) =
        match (conP cname (map (\(fn, _) -> varP fn) fields)) body []
      where
        body = normalB $
            case cons of
                (_:_:_) -> doE (pokeTag ix : map pokeField fields)
                _ -> doE (map pokeField fields)
    pokeTag ix = noBindS [| poke (ix :: $(conT tagType)) |]
    pokeField (fn, _) = noBindS [| poke $(varE fn) |]
-- | Declaration of the @Store@ instance for tuples of @n@ components.
--
-- The instance head is the @n@-tuple constructor applied to @n@ distinct
-- type variables, each of which receives a @Store@ constraint.  The
-- declaration itself comes from 'deriveGenericInstance'.
--
-- Type variables are named @a@ .. @z@, then @a1@ .. @z1@, @a2@ .. @z2@ and
-- so on, so that there are always @n@ of them; for @n <= 26@ the names are
-- the plain single letters.
deriveTupleStoreInstance :: Int -> Dec
deriveTupleStoreInstance n =
    deriveGenericInstance (map storePred tvs)
                          (foldl1 AppT (TupleT n : tvs))
  where
    tvs = take n (map (VarT . mkName) tyVarNames)
    -- Unbounded supply of distinct names; a fixed alphabet would silently
    -- leave the tuple constructor under-applied once @n@ exceeds it.
    tyVarNames :: [String]
    tyVarNames =
        [ letter : suffix
        | suffix <- "" : map show [(1 :: Int) ..]
        , letter <- ['a' .. 'z']
        ]
-- | A @Store@ instance declaration for @ty@ with context @cs@ and no method
-- definitions.
deriveGenericInstance :: Cxt -> Type -> Dec
deriveGenericInstance cs ty = plainInstanceD cs instHead noMethods
  where
    instHead = ConT ''Store `AppT` ty
    noMethods = []
-- | Generate @Store@ instances from existing 'Storable' instances.
--
-- Both instance sets are indexed by the head of their type argument (see
-- 'instancesMap' and 'postprocess').  For every 'Storable' entry that has no
-- @Store@ counterpart and whose type argument satisfies @p@, an instance is
-- produced whose methods are 'sizeStorableTy', 'peekStorableTy' and
-- 'pokeStorable'.  The first two receive the pretty-printed instance head
-- as a string literal.
deriveManyStoreFromStorable :: (Type -> Bool) -> Q [Dec]
deriveManyStoreFromStorable p = do
    storables <- postprocess . instancesMap <$> getInstances ''Storable
    stores <- postprocess . instancesMap <$> getInstances ''Store
    return $ M.elems $ M.mapMaybe storeFromStorable (storables `M.difference` stores)
  where
    storeFromStorable :: TypeclassInstance -> Maybe Dec
    storeFromStorable (TypeclassInstance cs ty _) =
        -- The head is the class applied to its argument; match on that
        -- shape rather than taking @head (tail ..)@, so that an unexpected
        -- head is skipped like any other rejected instance.
        case unAppsT ty of
            (_ : argTy : _) | p argTy ->
                Just $ makeStoreInstance cs argTy
                    (AppE (VarE 'sizeStorableTy) tyNameLit)
                    (AppE (VarE 'peekStorableTy) tyNameLit)
                    (VarE 'pokeStorable)
            _ -> Nothing
      where
        tyNameLit = LitE (StringL (pprint ty))
-- | Generate @Store@ instances for primitive vectors, one per 'PV.Prim'
-- instance.
--
-- The 'PV.Prim' instances are re-keyed as @PV.Vector@ applied to their
-- original key and entries whose key also occurs among the @Store@
-- instances are dropped.  Each generated instance:
--
-- * reports a @VarSize@ of one @Int@ plus the element size times the
--   vector length;
--
-- * pokes the length followed by the @len * sz@ bytes of the underlying
--   byte array starting at byte @offset * sz@;
--
-- * peeks the length, reads @len * sz@ bytes into a fresh byte array and
--   wraps it as a vector with offset 0.
--
-- Fails in 'Q' if an entry does not have the expected shape.
deriveManyStorePrimVector :: Q [Dec]
deriveManyStorePrimVector = do
    prims <- postprocess . instancesMap <$> getInstances ''PV.Prim
    stores <- postprocess . instancesMap <$> getInstances ''Store
    let primInsts =
            M.mapKeys (map (AppT (ConT ''PV.Vector))) prims
            `M.difference`
            stores
    forM (M.toList primInsts) $ \primInst -> case primInst of
        ([_], TypeclassInstance cs ty _)
            -- The element type is the argument of the 'PV.Prim' head; a
            -- head without one falls through to the failure below.
            | (_ : argTy : _) <- unAppsT ty -> do
                sizeExpr <- [|
                    VarSize $ \x ->
                        I# $(primSizeOfExpr (ConT ''Int)) +
                        I# $(primSizeOfExpr argTy) * PV.length x |]
                peekExpr <- [| do
                    len <- peek
                    let sz = I# $(primSizeOfExpr argTy)
                    array <- peekToByteArray $(lift ("Primitive Vector (" ++ pprint argTy ++ ")"))
                                             (len * sz)
                    return (PV.Vector 0 len array) |]
                pokeExpr <- [| \(PV.Vector offset len (ByteArray array)) -> do
                    let sz = I# $(primSizeOfExpr argTy)
                    poke len
                    pokeFromByteArray array (offset * sz) (len * sz) |]
                return $ makeStoreInstance cs (AppT (ConT ''PV.Vector) argTy) sizeExpr peekExpr pokeExpr
        _ -> fail "Invariant violated in deriveManyStorePrimVector"
-- | Expression applying @sizeOf#@ to a placeholder of type @ty@.
--
-- The placeholder is an 'error' call annotated with @ty@; it exists only to
-- select the instance.
primSizeOfExpr :: Type -> ExpQ
primSizeOfExpr ty = appE (varE 'sizeOf#) placeholder
  where
    placeholder = sigE [| error "sizeOf# evaluated its argument" |] (return ty)
-- | Generate @Store@ instances for the unboxed vector instances reported by
-- 'getUnboxInfo'.
--
-- Each instance is keyed by @UV.Vector@ applied to its type argument, and
-- keys that also occur among the existing @Store@ instances are dropped.
-- For the remainder, 'deriveStore' is run on the instance's constructors
-- with the instance's own context extended by a @Store (UV.Vector v)@
-- constraint for every type variable @v@ found in the key type.
deriveManyStoreUnboxVector :: Q [Dec]
deriveManyStoreUnboxVector = do
    unboxes <- getUnboxInfo
    stores <- postprocess . instancesMap <$> getInstances ''Store
    let candidates = M.fromList
            [ ([AppT (ConT ''UV.Vector) ty], (preds, cons))
            | (preds, ty, cons) <- unboxes
            ]
        missing = candidates `M.difference` stores
    forM (M.toList missing) deriveOne
  where
    deriveOne entry = case entry of
        ([ty], (preds, cons)) ->
            let vectorPreds =
                    map (storePred . AppT (ConT ''UV.Vector)) (listify isVarT ty)
            in deriveStore (nub (preds ++ vectorPreds)) ty cons
        _ -> fail "impossible case in deriveManyStoreUnboxVector"
-- | Reify the @UV.Vector@ family and describe each of its instances as
-- (context, type argument, constructors).
--
-- Only @newtype instance@ and @data instance@ declarations with exactly one
-- type argument are understood; anything else is reported with 'error'.
-- Every type variable name in the result is passed through 'dequalify'.
getUnboxInfo :: Q [(Cxt, Type, [DataCon])]
getUnboxInfo = do
    FamilyI _ insts <- reify ''UV.Vector
    return [ everywhere (id `extT` dequalVarT) (fromInst inst) | inst <- insts ]
  where
#if MIN_VERSION_template_haskell(2,11,0)
    fromInst (NewtypeInstD ctx _ [argTy] _ con _) = (ctx, argTy, conToDataCons con)
    fromInst (DataInstD ctx _ [argTy] _ cons _) = (ctx, argTy, concatMap conToDataCons cons)
#else
    fromInst (NewtypeInstD ctx _ [argTy] con _) = (ctx, argTy, conToDataCons con)
    fromInst (DataInstD ctx _ [argTy] cons _) = (ctx, argTy, concatMap conToDataCons cons)
#endif
    fromInst other = error ("Unexpected result from reifying Unboxed Vector instances: " ++ pprint other)
    dequalVarT t = case t of
        VarT n -> VarT (dequalify n)
        _ -> t
-- | Keep only the unambiguous entries of an instance map: those whose key
-- is a single type and whose value list holds exactly one element.  The
-- surviving values are unwrapped from their singleton list.
postprocess :: M.Map [Type] [a] -> M.Map [Type] a
postprocess = M.mapMaybeWithKey keepUnique
  where
    keepUnique [_] [inst] = Just inst
    keepUnique _ _ = Nothing
-- | Assemble a @Store@ instance declaration for @ty@ with context @cs@,
-- binding @size@, @peek@ and @poke@ to the given expressions.
makeStoreInstance :: Cxt -> Type -> Exp -> Exp -> Exp -> Dec
makeStoreInstance cs ty sizeExpr peekExpr pokeExpr =
    plainInstanceD cs (AppT (ConT ''Store) ty) methods
  where
    methods =
        [ method 'size sizeExpr
        , method 'peek peekExpr
        , method 'poke pokeExpr
        ]
    -- A method definition of the form @name = body@.
    method name body = ValD (VarP name) (NormalB body) []
-- | For every instance of class @n@, the list of types the class is applied
-- to in the instance head.
getAllInstanceTypes :: Name -> Q [[Type]]
getAllInstanceTypes n = do
    insts <- getInstances n
    return (map headArgs insts)
  where
    -- Everything after the class itself in the flattened instance head.
    headArgs (TypeclassInstance _ instHead _) = drop 1 (unAppsT instHead)
-- | For every instance of class @n@, the first type the class is applied to
-- in the instance head.
--
-- An instance head without any type argument yields an element that calls
-- 'error' when it is evaluated.
getAllInstanceTypes1 :: Name -> Q [Type]
getAllInstanceTypes1 n = map firstArg <$> getAllInstanceTypes n
  where
    firstArg :: [Type] -> Type
    firstArg (ty : _) = ty
    -- The message names this function and the condition that actually
    -- triggers it (no argument at all).
    firstArg [] = error "getAllInstanceTypes1: instance head has no type arguments"
-- | 'True' when no type variable occurs anywhere inside the type.
isMonoType :: Type -> Bool
isMonoType ty = null tyVars
  where
    tyVars = listify isVarT ty
-- | 'True' exactly for types built with the 'VarT' constructor.
isVarT :: Type -> Bool
isVarT ty = case ty of
    VarT _ -> True
    _ -> False
-- | Group instances by the heads (see 'getTyHead') of the types in their
-- instance head.  Instances sharing a key are collected into one list.
instancesMap :: [TypeclassInstance] -> M.Map [Type] [TypeclassInstance]
instancesMap insts =
    M.fromListWith (++) [ (keyOf inst, [inst]) | inst <- insts ]
  where
    keyOf = map getTyHead . instanceArgTypes
-- | The types a class is applied to in an instance head, i.e. the flattened
-- head without its first element.
instanceArgTypes :: TypeclassInstance -> [Type]
instanceArgTypes (TypeclassInstance _ instHead _) =
    case unAppsT instHead of
        [] -> []
        _ : args -> args
-- | The outermost head of a type: kind signatures and @forall@s are looked
-- through and applications are followed down their function side until
-- something else is reached.
getTyHead :: Type -> Type
getTyHead ty = case ty of
    SigT inner _ -> getTyHead inner
    ForallT _ _ body -> getTyHead body
    AppT fun _ -> getTyHead fun
    _ -> ty
-- | The constraint @Store ty@, in whichever representation of predicates
-- the template-haskell version in use provides.
storePred :: Type -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
storePred = AppT (ConT ''Store)
#else
storePred ty = ClassP ''Store [ty]
#endif