module Database.PostgreSQL.Typed.TH
( getTPGDatabase
, withTPGConnection
, useTPGDatabase
, reloadTPGTypes
, TPGValueInfo(..)
, tpgDescribe
, tpgTypeEncoder
, tpgTypeDecoder
) where
import Control.Applicative ((<$>), (<$), (<|>))
import Control.Concurrent.MVar (MVar, newMVar, takeMVar, putMVar, modifyMVar_)
import Control.Exception (mask, onException, finally)
import Control.Monad (liftM2)
import qualified Data.Foldable as Fold
import qualified Data.IntMap.Lazy as IntMap
import Data.List (find)
import Data.Maybe (isJust, fromMaybe)
import qualified Data.Traversable as Tv
import qualified Language.Haskell.TH as TH
import Network (PortID(UnixSocket, PortNumber), PortNumber)
import System.Environment (lookupEnv)
import System.IO.Unsafe (unsafePerformIO, unsafeInterleaveIO)

import Database.PostgreSQL.Typed.Types
import Database.PostgreSQL.Typed.Protocol
-- | A PostgreSQL type name as a plain string (for example a value produced by
-- @format_type@ in the catalogue query). It is spliced in as the type-level
-- string argument of 'PGTypeName'.
type TPGType = String
-- | Build the compile-time connection settings from environment variables,
-- starting from 'defaultPGDatabase'.
--
-- * @TPG_USER@, then @USER@, then @"postgres"@: user name
-- * @TPG_DB@: database name (defaults to the user name)
-- * @TPG_HOST@: host (defaults to @"localhost"@)
-- * @TPG_PORT@: TCP port (defaults to 5432); must be an integer in 0–65535
-- * @TPG_SOCK@: if set, a Unix socket path used as the port instead of
--   @TPG_PORT@
-- * @TPG_PASS@: password (defaults to empty)
-- * @TPG_DEBUG@: if set (to any value), enables 'pgDBDebug'
getTPGDatabase :: IO PGDatabase
getTPGDatabase = do
  user <- fromMaybe "postgres" <$> liftM2 (<|>) (lookupEnv "TPG_USER") (lookupEnv "USER")
  db <- fromMaybe user <$> lookupEnv "TPG_DB"
  host <- fromMaybe "localhost" <$> lookupEnv "TPG_HOST"
  pnum <- maybe (5432 :: PortNumber) parsePort <$> lookupEnv "TPG_PORT"
  port <- maybe (PortNumber pnum) UnixSocket <$> lookupEnv "TPG_SOCK"
  pass <- fromMaybe "" <$> lookupEnv "TPG_PASS"
  debug <- isJust <$> lookupEnv "TPG_DEBUG"
  return $ defaultPGDatabase
    { pgDBHost = host
    , pgDBPort = port
    , pgDBName = db
    , pgDBUser = user
    , pgDBPass = pass
    , pgDBDebug = debug
    }
  where
    -- Still evaluated lazily (only if the port is actually used), like the
    -- previous 'read'. Unlike it, this rejects garbage with a descriptive
    -- message instead of "no parse". It also refuses values outside the
    -- 16-bit port range instead of silently wrapping them. Parsing via
    -- 'Integer' avoids 'Int' overflow on huge inputs.
    parsePort :: String -> PortNumber
    parsePort s = case (reads s :: [(Integer, String)]) of
      [(n, rest)] | null (words rest), n >= 0, n <= 65535 -> fromInteger n
      _ -> error $ "getTPGDatabase: invalid TPG_PORT value: " ++ show s
-- | Process-wide compile-time state. It holds the database settings to use
-- and, once connected, the cached connection state.
--
-- It starts with the settings from 'getTPGDatabase' (evaluated lazily, on
-- first use) and no connection.
--
-- This is a top-level 'MVar' created with 'unsafePerformIO', so it must never
-- be inlined. Every inlined copy would allocate its own 'MVar', silently
-- splitting the shared state and the locking built on it.
tpgState :: MVar (PGDatabase, Maybe TPGState)
tpgState = unsafePerformIO $ do
  db <- unsafeInterleaveIO getTPGDatabase
  newMVar (db, Nothing)
{-# NOINLINE tpgState #-}
-- | Cached compile-time state: an open connection together with the table of
-- PostgreSQL type names loaded from it.
data TPGState = TPGState
  { tpgConnection :: PGConnection
    -- ^ The connection this state belongs to.
  , tpgTypes :: IntMap.IntMap TPGType
    -- ^ Type names keyed by OID (converted to 'Int'), filled in by
    -- 'tpgLoadTypes'. The field has no strictness annotation, which keeps
    -- that function's deferred catalogue query from running until the map
    -- is first forced.
  }
-- | Attach an OID-to-type-name table, read from the system catalogue, to the
-- given state.
--
-- Domain types are mapped to the name of their base type. Types in the
-- @pg_toast@ and @information_schema@ namespaces are skipped. The query is
-- deferred with 'unsafeInterleaveIO': it runs on the state's connection only
-- when the resulting 'tpgTypes' map is first forced.
tpgLoadTypes :: TPGState -> IO TPGState
tpgLoadTypes tpg = do
  tl <- unsafeInterleaveIO $ pgSimpleQuery (tpgConnection tpg) "SELECT typ.oid, format_type(CASE WHEN typtype = 'd' THEN typbasetype ELSE typ.oid END, -1) FROM pg_catalog.pg_type typ JOIN pg_catalog.pg_namespace nsp ON typnamespace = nsp.oid WHERE nspname <> 'pg_toast' AND nspname <> 'information_schema' ORDER BY typ.oid"
  -- 'fromList' rather than 'fromAscList': ORDER BY sorts OIDs as unsigned
  -- 32-bit values. Those are not guaranteed to stay ascending once converted
  -- to 'Int' (e.g. where 'Int' is 32 bits), and 'fromAscList' would then build
  -- a corrupt map.
  return $ tpg{ tpgTypes = IntMap.fromList $ map typeRow $ Fold.toList $ snd tl }
  where
    -- Decode one (oid, type name) row.
    typeRow [PGTextValue to, PGTextValue tn] =
      ( fromIntegral (pgDecode (PGTypeProxy :: PGTypeName "oid") to :: OID) :: Int
      , pgDecode (PGTypeProxy :: PGTypeName "text") tn :: TPGType
      )
    -- Any other row shape (e.g. a NULL) is a protocol surprise. Report it
    -- clearly instead of as a non-exhaustive lambda pattern.
    typeRow _ = error "Database.PostgreSQL.Typed.TH.tpgLoadTypes: unexpected row from pg_type query (expected two non-null text columns)"
-- | Wrap a freshly opened connection in a 'TPGState' and attach its type
-- table via 'tpgLoadTypes'.
--
-- The empty map is only a placeholder, because 'tpgLoadTypes' replaces the
-- field. Using a total value instead of 'undefined' means nothing can blow up
-- if that ever changes.
tpgInit :: PGConnection -> IO TPGState
tpgInit c = tpgLoadTypes TPGState{ tpgConnection = c, tpgTypes = IntMap.empty }
-- | Run an action with exclusive access to the cached compile-time state. On
-- first use it connects (with the recorded settings) and loads the type table.
--
-- The state 'MVar' stays empty while the action runs, which serialises all
-- users. Asynchronous exceptions are masked everywhere except inside
-- connecting and the user action, so the 'MVar' is always refilled:
--
-- * if connecting fails, the settings are put back with no cached state;
-- * if the action fails, the (open) state is kept for later calls.
withTPGState :: (TPGState -> IO a) -> IO a
withTPGState f = mask $ \restore -> do
  (db, cached) <- takeMVar tpgState
  tpg <- restore (maybe (tpgInit =<< pgConnect db) return cached)
    `onException` putMVar tpgState (db, Nothing)
  restore (f tpg) `finally` putMVar tpgState (db, Just tpg)
-- | Run an action with exclusive use of the shared compile-time connection.
-- The connection is opened (and its type table loaded) on first use.
withTPGConnection :: (PGConnection -> IO a) -> IO a
withTPGConnection f = withTPGState $ \st -> f (tpgConnection st)
-- | Select the database that subsequent compile-time queries use. Produces no
-- declarations.
--
-- * If @db@ equals the currently selected database, an open connection is
--   kept (passed through 'pgReconnect').
-- * Otherwise any open connection is closed and the cached state dropped, so
--   the next use opens a fresh connection to @db@.
-- * If either step fails, @db@ is still recorded, with no cached connection.
useTPGDatabase :: PGDatabase -> TH.DecsQ
useTPGDatabase db = TH.runIO $ mask $ \restore -> do
  (current, cached) <- takeMVar tpgState
  -- Only the connection work is interruptible. Everywhere else async
  -- exceptions are masked, so the MVar cannot be left empty between the take
  -- and the put (which would deadlock every later splice).
  next <- restore (retarget current cached)
    `onException` putMVar tpgState (db, Nothing)
  putMVar tpgState (db, next)
  return []
  where
    retarget current cached
      | db == current = Tv.mapM reconnect cached
      | otherwise = Nothing <$ Fold.mapM_ (pgDisconnect . tpgConnection) cached
    reconnect t = do
      c <- pgReconnect (tpgConnection t) db
      return t{ tpgConnection = c }
-- | Re-read the type table of the cached connection, if one is open. Produces
-- no declarations.
reloadTPGTypes :: TH.DecsQ
reloadTPGTypes = TH.runIO $ do
  modifyMVar_ tpgState reload
  return []
  where
    reload (settings, cached) = do
      refreshed <- Tv.mapM tpgLoadTypes cached
      return (settings, refreshed)
-- | Look up the type name for an OID in the state's type table. An unknown
-- OID yields an error when the result is forced.
tpgType :: TPGState -> OID -> TPGType
tpgType TPGState{ tpgTypes = table } o =
  case IntMap.lookup (fromIntegral o) table of
    Just name -> name
    Nothing -> error $ "Unknown PostgreSQL type: " ++ show o
-- | Find the OID of a type by its exact name in the state's type table.
--
-- If several OIDs share the name, the lowest one wins. An unknown name is
-- reported through 'fail' in the caller's monad.
getTPGTypeOID :: Monad m => TPGState -> String -> m OID
getTPGTypeOID st t =
  case [ k | (k, name) <- IntMap.toList (tpgTypes st), t == name ] of
    k : _ -> return (fromIntegral k)
    [] -> fail $ "Unknown PostgreSQL type: " ++ t ++ "; be sure to use the exact type name from \\dTS"
-- | Ask the type checker whether a 'PGBinaryType' instance exists for the
-- type-level string naming this PostgreSQL type.
tpgTypeIsBinary :: TPGType -> TH.Q Bool
tpgTypeIsBinary t = TH.isInstance ''PGBinaryType [typeLit]
  where
    typeLit = TH.LitT (TH.StrTyLit t)
-- | Description of one query parameter or result column. It is produced by
-- 'tpgDescribe' and consumed by 'tpgTypeEncoder' / 'tpgTypeDecoder'.
data TPGValueInfo = TPGValueInfo
  { tpgValueName :: String
    -- ^ Column name for results; empty for parameters.
  , tpgValueTypeOID :: !OID
    -- ^ PostgreSQL type OID.
  , tpgValueType :: TPGType
    -- ^ PostgreSQL type name corresponding to that OID.
  , tpgValueBinary :: Bool
    -- ^ Whether the binary encoder/decoder is selected instead of the text one.
  , tpgValueNullable :: Bool
    -- ^ Whether NULL is allowed. For results this selects the nullable vs.
    -- @NotNull@ column decoder.
  }
-- | Describe a SQL statement at compile time. Returns information about its
-- parameters and about its result columns.
--
-- * @types@ names explicit parameter types; they are resolved to OIDs with
--   'getTPGTypeOID'.
-- * @nulls@ is forwarded to 'pgDescribe'.
-- * Parameters get an empty name and are always marked nullable.
-- * Binary transfer is only considered when compiled with @USE_BINARY@ (via
--   'tpgTypeIsBinary'). Otherwise every value uses the text format.
tpgDescribe :: String -> [String] -> Bool -> TH.Q ([TPGValueInfo], [TPGValueInfo])
tpgDescribe sql types nulls = do
  (pv, rv) <- TH.runIO $ withTPGState $ \tpg -> do
    -- The type table is loaded lazily by a query on the shared connection.
    -- Force it here, while 'withTPGState' still holds the connection. If the
    -- 'tpgType' lookups below forced it instead, the query could run after
    -- the lock has been released.
    _ <- return $! tpgTypes tpg
    at <- mapM (getTPGTypeOID tpg) types
    (pt, rt) <- pgDescribe (tpgConnection tpg) sql at nulls
    return
      ( map (\o -> TPGValueInfo
          { tpgValueName = ""
          , tpgValueTypeOID = o
          , tpgValueType = tpgType tpg o
          , tpgValueBinary = False
          , tpgValueNullable = True
          }) pt
      , map (\(c, o, n) -> TPGValueInfo
          { tpgValueName = c
          , tpgValueTypeOID = o
          , tpgValueType = tpgType tpg o
          , tpgValueBinary = False
          , tpgValueNullable = n
          }) rt
      )
#ifdef USE_BINARY
  liftM2 (,) (fillBin pv) (fillBin rv)
  where
    -- Mark each value binary iff its type has a 'PGBinaryType' instance.
    fillBin = mapM (\i -> do
      b <- tpgTypeIsBinary (tpgValueType i)
      return i{ tpgValueBinary = b })
#else
  return (pv, rv)
#endif
-- | Build the expression @f e (PGTypeProxy :: PGTypeName "t") v@. The proxy
-- is annotated with the type-level name of PostgreSQL type @t@.
typeApply :: TPGType -> TH.Name -> TH.Name -> TH.Name -> TH.Exp
typeApply t f e v =
  TH.AppE (TH.AppE (TH.AppE (TH.VarE f) (TH.VarE e)) proxy) (TH.VarE v)
  where
    proxy = TH.SigE (TH.ConE 'PGTypeProxy) proxyType
    proxyType = TH.AppT (TH.ConT ''PGTypeName) (TH.LitT (TH.StrTyLit t))
-- | Choose the parameter encoder for a value and apply it via 'typeApply'.
--
-- * @lit = True@ selects escaping (for literal substitution into the SQL).
-- * Otherwise the binary or text encoder is chosen according to
--   'tpgValueBinary'.
--
-- The result still expects the two remaining names for 'typeApply'.
tpgTypeEncoder :: Bool -> TPGValueInfo -> TH.Name -> TH.Name -> TH.Exp
tpgTypeEncoder lit v = typeApply (tpgValueType v) encoder
  where
    encoder
      | lit = 'pgEscapeParameter
      | tpgValueBinary v = 'pgEncodeBinaryParameter
      | otherwise = 'pgEncodeParameter
-- | Choose the column decoder for a value and apply it via 'typeApply'. The
-- choice is binary vs. text (from 'tpgValueBinary') crossed with nullable vs.
-- @NotNull@ (from 'tpgValueNullable').
tpgTypeDecoder :: TPGValueInfo -> TH.Name -> TH.Name -> TH.Exp
tpgTypeDecoder v = typeApply (tpgValueType v) decoder
  where
    decoder = case (tpgValueBinary v, tpgValueNullable v) of
      (True, True) -> 'pgDecodeBinaryColumn
      (True, False) -> 'pgDecodeBinaryColumnNotNull
      (False, True) -> 'pgDecodeColumn
      (False, False) -> 'pgDecodeColumnNotNull