{-# LANGUAGE TypeFamilies, FlexibleInstances, ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Database.Selda.Prepared (Preparable, Prepare, prepared) where
import Database.Selda.Backend.Internal
import Database.Selda.Column
import Database.Selda.Compile
import Database.Selda.Query.Type
import Database.Selda.SQL (param, paramType)
import Control.Exception
import Control.Monad.IO.Class
import qualified Data.IntMap as M
import Data.IORef
import Data.Proxy
import Data.Text (Text)
import Data.Typeable
import System.IO.Unsafe
-- | Exception carrying the index of a query argument.
--
--   The 'Preparable' instance for functions plants
--   @throw (Placeholder n)@ as the payload of a literal, and 'inspectParams'
--   catches it again to find out which argument a parameter stands for.
data Placeholder = Placeholder Int
  deriving Show
instance Exception Placeholder
-- | Index given to the first argument of a prepared query. Compilation in
--   'mkFun' starts numbering placeholders here, and 'replaceParams' starts
--   counting the supplied arguments from the same value.
firstParamIx :: Int
firstParamIx = 0
-- | Strip every argument arrow from a type, then the outermost type
--   application, leaving the innermost applied type:
--   @ResultT (x -> y -> t r)@ reduces to @r@.
type family ResultT f where
  -- Skip over one argument and keep looking at the remainder.
  ResultT (arg -> rest) = ResultT rest
  -- No arrows left: return whatever the final type is applied to.
  ResultT (wrap inner)  = inner
-- | Constraint tying a parameterised query type @q@ to the function type @f@
--   it is prepared into.
--
--   Each @Col s arg@ argument of the query must line up with a plain @arg@
--   argument of the function. Once the arguments are exhausted, the function
--   must return a list in some monad @m@ whose element type is @Res@ of the
--   query's result and whose @Backend m@ equals the query's scope parameter.
type family Equiv q f where
  Equiv (Col s arg -> q) (arg -> f) = Equiv q f
  Equiv (Query s row)    (m [out])  = (Res row ~ out, Backend m ~ s)
-- | Outcome of compiling a query for preparation: the statement text, the
--   parameter list (@Left n@ marks the slot for argument @n@, @Right p@ is a
--   parameter already known at compile time), and the parameter types that
--   are handed to 'prepareStmt'.
type CompResult = (Text, [Either Int Param], [SqlTypeRep])
-- | Query types that can be compiled into a 'CompResult'.
class Preparable q where
  -- | Compile the query, feeding a placeholder to every argument it is
  --   still waiting for.
  mkQuery :: MonadSelda m
          => Int          -- ^ Index to give the next argument placeholder.
          -> q            -- ^ The query, possibly still awaiting arguments.
          -> [SqlTypeRep] -- ^ Types of the arguments seen so far, most
                          --   recent first.
          -> m CompResult
-- | Relates a query type @q@ to the function type @f@ that 'prepared'
--   produces for it.
class Prepare q f where
  -- | Build the function that collects the arguments and then prepares
  --   (if necessary) and runs the query.
  mkFun :: Preparable q
        => IORef (Maybe (BackendID, CompResult))
           -- ^ Cache of the most recent compilation, tagged with the
           --   backend it was compiled for.
        -> StmtID
           -- ^ Identifier under which the prepared statement is stored on
           --   each connection.
        -> q
           -- ^ The query to prepare.
        -> [Param]
           -- ^ Arguments collected so far, most recent first.
        -> f
-- | One more argument to collect: turn it into a 'Param', push it onto the
--   front of the accumulator and continue with the rest of the function type.
instance (SqlType a, Prepare q b) => Prepare q (a -> b) where
  mkFun ref sid qry collected = \arg ->
    mkFun ref sid qry (param arg : collected)
-- | All arguments collected: run the query on the current connection,
--   preparing it first if this connection has no statement stored under the
--   given statement identifier.
instance (Typeable a, MonadSelda m, a ~ Res (ResultT q), Result (ResultT q)) =>
         Prepare q (m [a]) where
  mkFun ref (StmtID sid) qry arguments = withConnection $ \conn -> do
      known <- liftIO $ M.lookup sid <$> readIORef (connStmts conn)
      case known of
        -- Already prepared on this connection: just execute it.
        Just stm -> liftIO $ execute conn stm
        Nothing  -> do
          let backend = connBackend conn
          (sql, params, reps) <- compileFor backend
          -- Preparing the statement and registering it on the connection
          -- happen with asynchronous exceptions masked; only the actual
          -- execution runs with the caller's masking state restored.
          liftIO $ mask $ \restore -> do
            hdl <- prepareStmt backend (StmtID sid) reps sql
            let stm = SeldaStmt
                  { stmtHandle = hdl
                  , stmtParams = params
                  , stmtText   = sql
                  }
            atomicModifyIORef' (connStmts conn) $ \stmts ->
              (M.insert sid stm stmts, ())
            restore $ execute conn stm
    where
      -- Arguments are accumulated most recent first; put them back in order.
      args = reverse arguments

      -- Reuse the cached compilation when it was made for this backend,
      -- otherwise compile afresh and overwrite the cache.
      -- NOTE(review): the cache is read and written non-atomically, so
      -- concurrent callers may each compile and store a result; this looks
      -- harmless since only a compilation is repeated, but confirm.
      compileFor backend = do
        cached <- liftIO $ readIORef ref
        case cached of
          Just (bid, comp) | bid == backendId backend -> return comp
          _ -> do
            comp <- mkQuery firstParamIx qry []
            liftIO $ writeIORef ref (Just (backendId backend, comp))
            return comp

      -- Fill in the placeholders, run the statement and decode each row of
      -- the second component of the backend's answer.
      execute conn stm =
        map (buildResult (Proxy :: Proxy (ResultT q))) . snd
          <$> runPrepared (connBackend conn)
                          (stmtHandle stm)
                          (replaceParams (stmtParams stm) args)
-- | A query still awaiting a column argument: feed it a literal whose payload
--   throws 'Placeholder' with the current index, record the argument's SQL
--   type, and continue with the rest of the query under the next index.
instance (SqlType a, Preparable b) => Preparable (Col s a -> b) where
  mkQuery n f ts =
    let rep  = sqlType (Proxy :: Proxy a)
        -- Not a real value: forcing it raises the placeholder exception.
        hole = throw (Placeholder n) :: Lit a
    in  mkQuery (n + 1) (f (One (Lit (LCustom rep hole)))) (rep : ts)
-- | A query with no arguments left: compile it with the backend's
--   pretty-printer configuration, then classify every compiled parameter as
--   either an argument placeholder or a concrete value.
instance Result a => Preparable (Query s a) where
  mkQuery _ qry argTypes = withBackend $ \backend ->
    case compileWith (ppConfig backend) qry of
      (sql, rawParams) -> liftIO $ do
        -- The types were accumulated most recent first; reverse them so
        -- that position @i@ holds the type of argument @i@.
        (params, reps) <- inspectParams (reverse argTypes) rawParams
        return (sql, params, reps)
-- | Turn a parameterised query into a function that takes the query's
--   arguments as plain values and runs it as a prepared statement.
--
--   Every evaluation of @prepared q@ allocates its own compilation cache and
--   its own statement identifier inside 'unsafePerformIO'.
--   NOTE(review): the NOINLINE pragma presumably keeps GHC from duplicating
--   that allocation at call sites; callers should bind the result once and
--   reuse it so the cache and identifier are shared — confirm.
{-# NOINLINE prepared #-}
prepared :: (Preparable q, Prepare q f, Equiv q f) => q -> f
prepared q = unsafePerformIO $
  newIORef Nothing >>= \cache ->
  freshStmtId      >>= \stmtId ->
  return (mkFun cache stmtId q [])
-- | Replace every placeholder (@Left n@) in a compiled parameter list with
--   the corresponding supplied argument, leaving concrete parameters
--   (@Right p@) in place.
--
--   Arguments are numbered consecutively starting at 'firstParamIx'. A
--   placeholder whose index has no matching argument yields an 'error' value
--   in its position.
--
--   The arguments are indexed once in an 'M.IntMap', so the substitution is a
--   single pass over the parameters instead of one full pass per argument.
replaceParams :: [Either Int Param] -> [Param] -> [Param]
replaceParams params args = map resolve params
  where
    -- Supplied arguments keyed by placeholder index. The map is lazy in its
    -- values, so no argument is forced here.
    supplied = M.fromList (zip [firstParamIx ..] args)

    resolve (Right p) = p
    resolve (Left ix) =
      M.findWithDefault (error "BUG: query parameter not substituted!")
                        ix
                        supplied
-- | Classify every compiled parameter, in order.
--
--   A parameter whose forcing raises 'Placeholder' becomes @Left ix@ and
--   takes its type from position @ix@ of the given type list. Any other
--   parameter becomes @Right p@ with the type reported by 'paramType'.
--   Returns the classified parameters together with their types.
inspectParams :: [SqlTypeRep] -> [Param] -> IO ([Either Int Param], [SqlTypeRep])
inspectParams ts = fmap unzip . mapM inspect
  where
    inspect p = do
      -- Only the placeholder exception is caught; the handler type is fixed
      -- by the 'Placeholder' pattern below.
      outcome <- try $ pure $! forceParam p
      return $ case outcome of
        Right lit             -> (Right lit, paramType lit)
        Left (Placeholder ix) -> (Left ix, ts !! ix)
-- | Force the payload of an 'LCustom' literal to weak head normal form, so
--   that a placeholder planted there raises its exception. Any other
--   parameter is returned without further evaluation.
forceParam :: Param -> Param
forceParam p = case p of
  Param (LCustom _ payload) -> payload `seq` p
  _                         -> p