{-# LANGUAGE CPP, BangPatterns, DoAndIfThenElse, RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable, DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Database.PostgreSQL.Simple.Internal where
import Control.Applicative
import Control.Exception
import Control.Concurrent.MVar
import Control.Monad(MonadPlus(..))
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.ByteString.Builder ( Builder, byteString )
import Data.Char (ord)
import Data.Int (Int64)
import qualified Data.IntMap as IntMap
import Data.IORef
import Data.Maybe(fromMaybe)
import Data.Monoid
import Data.String
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Data.Typeable
import Data.Word
import Database.PostgreSQL.LibPQ(Oid(..))
import qualified Database.PostgreSQL.LibPQ as PQ
import Database.PostgreSQL.LibPQ(ExecStatus(..))
import Database.PostgreSQL.Simple.Compat ( toByteString )
import Database.PostgreSQL.Simple.Ok
import Database.PostgreSQL.Simple.ToField (Action(..), inQuotes)
import Database.PostgreSQL.Simple.Types (Query(..))
import Database.PostgreSQL.Simple.TypeInfo.Types(TypeInfo)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Reader
import Control.Monad.Trans.Class
import GHC.Generics
import GHC.IO.Exception
#if !defined(mingw32_HOST_OS)
import Control.Concurrent(threadWaitRead, threadWaitWrite)
#endif
-- | A single result column handed to a field parser: the libpq result it
-- belongs to, its column index, and the type OID recorded for it.
data Field = Field {
     result   :: !PQ.Result
     -- ^ The libpq result set containing this column.
   , column   :: {-# UNPACK #-} !PQ.Column
     -- ^ Index of the column within 'result'.
   , typeOid  :: {-# UNPACK #-} !PQ.Oid
     -- ^ Type OID of the column. NOTE(review): presumably taken from the
     -- result's column type; it is not computed in this module.
   }
-- | Per-connection cache of 'TypeInfo', keyed by an 'Int' (presumably the
-- type OID converted with 'oid2int'; the code that fills it is not here).
type TypeInfoCache = IntMap.IntMap TypeInfo
-- | A database connection handle.
--
-- The libpq connection is stored in an 'MVar', so 'withConnection' gives
-- each caller exclusive use of it while the action runs.
data Connection = Connection {
     connectionHandle  :: {-# UNPACK #-} !(MVar PQ.Connection)
     -- ^ The libpq connection, guarded for exclusive use.
   , connectionObjects :: {-# UNPACK #-} !(MVar TypeInfoCache)
     -- ^ Cache of 'TypeInfo' entries for this connection.
   , connectionTempNameCounter :: {-# UNPACK #-} !(IORef Int64)
     -- ^ Counter used by 'newTempName' to make unique names.
   } deriving (Typeable)
-- | Two 'Connection's are equal exactly when they share the same handle
-- 'MVar', i.e. they refer to the same guarded libpq connection slot.
instance Eq Connection where
  Connection{connectionHandle = h1} == Connection{connectionHandle = h2} =
    h1 == h2
-- | An error reported for a failed query or connection.
--
-- 'throwResultError' fills it from a libpq result's error fields;
-- 'fatalError' builds one that carries only a message.
data SqlError = SqlError {
     sqlState       :: ByteString
     -- ^ SQLSTATE code; empty when none is available.
   , sqlExecStatus  :: ExecStatus
     -- ^ libpq execution status associated with the error.
   , sqlErrorMsg    :: ByteString
     -- ^ Primary error message.
   , sqlErrorDetail :: ByteString
     -- ^ Detail text; empty when absent.
   , sqlErrorHint   :: ByteString
     -- ^ Hint text; empty when absent.
   } deriving (Eq, Show, Typeable)
-- | Build a 'SqlError' with status 'FatalError' that carries only a
-- message; the SQLSTATE, detail and hint are left empty.
fatalError :: ByteString -> SqlError
fatalError msg = SqlError
  { sqlState       = ""
  , sqlExecStatus  = FatalError
  , sqlErrorMsg    = msg
  , sqlErrorDetail = ""
  , sqlErrorHint   = ""
  }
-- | Default 'Exception' methods, so 'SqlError' can be thrown with
-- 'throwIO' and caught by type.
instance Exception SqlError
-- | Raised when a query runs but its outcome does not fit the operation.
-- For example, 'finishExecute' raises it for an empty query, for a result
-- that has columns, or for a COPY result.
data QueryError = QueryError {
     qeMessage :: String
     -- ^ Human-readable description of the problem.
   , qeQuery   :: Query
     -- ^ The query that triggered the error.
   } deriving (Eq, Show, Typeable)

instance Exception QueryError
-- | Raised (via 'fmtError') when a query cannot be rendered with its
-- parameters, for example when libpq fails to escape a value.
data FormatError = FormatError {
     fmtMessage :: String
     -- ^ Description of the formatting failure.
   , fmtQuery   :: Query
     -- ^ The query template being formatted.
   , fmtParams  :: [ByteString]
     -- ^ The parameters, flattened to raw bytes for diagnostics.
   } deriving (Eq, Show, Typeable)

instance Exception FormatError
-- | Individual connection parameters, rendered into a libpq connection
-- string by 'postgreSQLConnectionString'. Empty string fields and a port
-- of 0 are left out of that string.
data ConnectInfo = ConnectInfo {
     connectHost     :: String
   , connectPort     :: Word16
   , connectUser     :: String
   , connectPassword :: String
   , connectDatabase :: String
   } deriving (Generic,Eq,Read,Show,Typeable)
-- | Default parameters: host @127.0.0.1@, port 5432, user @postgres@, an
-- empty password and an empty database name. The two empty fields are
-- omitted by 'postgreSQLConnectionString'.
defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
  ConnectInfo "127.0.0.1" -- host
              5432        -- port
              "postgres"  -- user
              ""          -- password
              ""          -- database
-- | Open a connection described by a 'ConnectInfo'. This renders it with
-- 'postgreSQLConnectionString' and hands the result to 'connectPostgreSQL'.
connect :: ConnectInfo -> IO Connection
connect info = connectPostgreSQL (postgreSQLConnectionString info)
-- | Open a connection from a libpq connection string.
--
-- After connecting, the session is configured with ISO date style and
-- UTF8 client encoding. On servers reporting version 80200 or later it
-- also turns @standard_conforming_strings@ on.
--
-- Throws a 'SqlError' (via 'fatalError') if libpq reports a bad
-- connection status. Any exception from the initial @SET@ commands
-- propagates unchanged. In both cases the libpq connection is finished
-- before the exception escapes, rather than lingering until GC.
connectPostgreSQL :: ByteString -> IO Connection
connectPostgreSQL connstr = do
    conn <- connectdb connstr
    stat <- PQ.status conn
    case stat of
      PQ.ConnectionOk -> do
          connectionHandle  <- newMVar conn
          connectionObjects <- newMVar IntMap.empty
          connectionTempNameCounter <- newIORef 0
          let wconn = Connection{..}
          version <- PQ.serverVersion conn
          let settings
                | version < 80200 = "SET datestyle TO ISO;SET client_encoding TO UTF8"
                | otherwise       = "SET datestyle TO ISO;SET client_encoding TO UTF8;SET standard_conforming_strings TO on"
          -- The half-initialised connection is never returned on failure,
          -- so release it right away instead of leaving it to GC.
          _ <- execute_ wconn settings `onException` PQ.finish conn
          return wconn
      _ -> do
          -- Read the message first: it belongs to the connection that is
          -- about to be finished.
          msg <- fromMaybe "connectPostgreSQL error" <$> PQ.errorMessage conn
          PQ.finish conn
          throwIO $ fatalError msg
-- | Establish a libpq connection.
--
-- On Windows this is the blocking 'PQ.connectdb'. Elsewhere it drives
-- libpq's non-blocking connect ('PQ.connectStart' / 'PQ.connectPoll'). It
-- waits on the socket with 'threadWaitRead' / 'threadWaitWrite', so only
-- the calling Haskell thread blocks.
--
-- Throws an 'IOError': from 'throwLibPQError' when polling fails, or from
-- 'fdError' when libpq has no socket to wait on.
connectdb :: ByteString -> IO PQ.Connection
#if defined(mingw32_HOST_OS)
connectdb = PQ.connectdb
#else
connectdb conninfo = PQ.connectStart conninfo >>= poll
  where
    poll c = do
      pollStatus <- PQ.connectPoll c
      case pollStatus of
        PQ.PollingOk      -> return c
        PQ.PollingFailed  -> throwLibPQError c "connection failed"
        PQ.PollingReading -> waitThenPoll threadWaitRead c
        PQ.PollingWriting -> waitThenPoll threadWaitWrite c

    -- Wait until the socket is ready in the direction libpq asked for,
    -- then advance the connection state machine again.
    waitThenPoll waitOn c = do
      sock <- PQ.socket c
      case sock of
        Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.connectPostgreSQL"
        Just fd -> waitOn fd >> poll c
#endif
-- | Render a 'ConnectInfo' as a libpq keyword/value connection string,
-- e.g. @host='127.0.0.1' port=5432 user='postgres'@.
--
-- Empty string fields and a port of 0 are left out. String values are
-- wrapped in single quotes, and any backslash or single quote inside them
-- is backslash-escaped.
--
-- The result is UTF-8 encoded. Building it with 'fromString' would
-- truncate every character to its low byte, which corrupts non-ASCII
-- passwords, user names and database names. ASCII input is unaffected.
postgreSQLConnectionString :: ConnectInfo -> ByteString
postgreSQLConnectionString connectInfo = TE.encodeUtf8 (T.pack connstr)
  where
    connstr = str "host=" connectHost
            $ num "port=" connectPort
            $ str "user=" connectUser
            $ str "password=" connectPassword
            $ str "dbname=" connectDatabase
            $ []

    -- Emit @name='value'@ unless the field is empty.
    str name field
      | null value = id
      | otherwise  = showString name . quoteValue value . space
      where value = field connectInfo

    -- Emit @name=value@ for a non-zero numeric field.
    num name field
      | value <= 0 = id
      | otherwise  = showString name . shows value . space
      where value = field connectInfo

    -- Single-quote a value, backslash-escaping backslashes and quotes.
    -- (Named so it no longer shadows the top-level 'quote'.)
    quoteValue value rest = '\'' : foldr escape ('\'' : rest) value
      where
        escape c cs = case c of
          '\\' -> '\\' : '\\' : cs
          '\'' -> '\\' : '\'' : cs
          _    -> c : cs

    -- Separate from whatever follows, without leaving a trailing space.
    space [] = []
    space xs = ' ' : xs
-- | Convert a libpq 'Oid' to an 'Int' (e.g. for use as an 'IntMap' key).
oid2int :: Oid -> Int
oid2int oid = case oid of
  Oid raw -> fromIntegral raw
{-# INLINE oid2int #-}
-- | Send a query string and return its libpq result.
--
-- On Windows this is the blocking 'PQ.exec'. Elsewhere the query is sent
-- with 'PQ.sendQuery', and results are collected by waiting on the socket
-- with 'threadWaitRead' whenever 'PQ.isBusy' reports that libpq needs
-- more input.
--
-- If the query string produces several results, the last one is
-- returned. A 'PQ.CopyOut' or 'PQ.CopyIn' result is returned immediately,
-- without draining further results. Error statuses are returned rather
-- than thrown; callers such as 'finishExecute' inspect the status.
--
-- Throws an 'IOError' (via 'throwLibPQError' or 'fdError') if the query
-- cannot be sent, no result at all is produced, or libpq has no socket.
exec :: Connection
     -> ByteString
     -> IO PQ.Result
#if defined(mingw32_HOST_OS)
exec conn sql =
    withConnection conn $ \h -> do
        mres <- PQ.exec h sql
        case mres of
          Nothing  -> throwLibPQError h "PQexec returned no results"
          Just res -> return res
#else
exec conn sql =
    withConnection conn $ \h -> do
        success <- PQ.sendQuery h sql
        if success
        then awaitResult h Nothing
        else throwLibPQError h "PQsendQuery failed"
  where
    -- Block until the socket is readable, feed the available bytes to
    -- libpq, then try again to collect results. @mres@ is the most recent
    -- result seen so far (if any).
    awaitResult h mres = do
        mfd <- PQ.socket h
        case mfd of
          Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
          Just fd -> do
            threadWaitRead fd
            -- NOTE(review): the Bool result of consumeInput is ignored;
            -- a failure surfaces only through later libpq calls.
            _ <- PQ.consumeInput h
            getResult h mres
    -- Pull results with 'PQ.getResult' while 'PQ.isBusy' says it will
    -- not block. 'Nothing' from getResult ends the sequence, and the last
    -- result remembered in @mres@ is returned.
    getResult h mres = do
        isBusy <- PQ.isBusy h
        if isBusy
        then awaitResult h mres
        else do
          mres' <- PQ.getResult h
          case mres' of
            Nothing -> case mres of
                         Nothing  -> throwLibPQError h "PQgetResult returned no results"
                         Just res -> return res
            Just res -> do
                status <- PQ.resultStatus res
                case status of
                  -- Remember this result and keep draining.
                  PQ.EmptyQuery    -> getResult h mres'
                  PQ.CommandOk     -> getResult h mres'
                  PQ.TuplesOk      -> getResult h mres'
                  -- COPY results are handed back straight away.
                  PQ.CopyOut       -> return res
                  PQ.CopyIn        -> return res
                  -- Error results are remembered too, not thrown here.
                  PQ.BadResponse   -> getResult h mres'
                  PQ.NonfatalError -> getResult h mres'
                  PQ.FatalError    -> getResult h mres'
#endif
-- | Run a statement that is not expected to return rows, and return the
-- affected-row count as interpreted by 'finishExecute'.
execute_ :: Connection -> Query -> IO Int64
execute_ conn q@(Query stmt) = exec conn stmt >>= finishExecute conn q
-- | Interpret the result of a statement that should not return rows.
--
-- Returns the affected-row count reported by 'PQ.cmdTuples', or 0 when
-- there is none. Throws:
--
-- * 'QueryError' for an empty query, a result with columns, or COPY;
-- * 'SqlError' (via 'throwResultError') for error statuses;
-- * an 'ErrorCall' if the reported row count is not a decimal number.
--
-- The row count is parsed before returning, so that last error is raised
-- inside this action rather than wherever the caller later forces it.
finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
finishExecute _conn q result = do
    status <- PQ.resultStatus result
    case status of
      PQ.EmptyQuery -> throwIO $ QueryError "execute: Empty query" q
      PQ.CommandOk -> do
          ncols <- PQ.nfields result
          if ncols /= 0
          then throwIO $ QueryError ("execute resulted in " ++ show ncols ++
                                     "-column result") q
          else do
            nstr <- PQ.cmdTuples result
            -- Force the parse here, not in a thunk returned to the caller.
            return $! maybe 0 parseRowCount nstr
      PQ.TuplesOk -> do
          ncols <- PQ.nfields result
          throwIO $ QueryError ("execute resulted in " ++ show ncols ++
                                "-column result") q
      PQ.CopyOut ->
          throwIO $ QueryError "execute: COPY TO is not supported" q
      PQ.CopyIn ->
          throwIO $ QueryError "execute: COPY FROM is not supported" q
      PQ.BadResponse   -> throwResultError "execute" result status
      PQ.NonfatalError -> throwResultError "execute" result status
      PQ.FatalError    -> throwResultError "execute" result status
  where
    -- Parse an unsigned decimal count; the empty string yields 0.
    -- (Renamed from 'toInteger', which shadowed the Prelude function.)
    parseRowCount :: ByteString -> Int64
    parseRowCount str = B8.foldl' step 0 str
      where
        step acc c
          | '0' <= c && c <= '9' = 10 * acc + fromIntegral (ord c - ord '0')
          | otherwise = error ("finishExecute: not an int: " ++ B8.unpack str)
-- | Throw a 'SqlError' built from the error fields of a libpq result.
--
-- Missing fields (primary message, detail, hint, SQLSTATE) become empty
-- strings. The first argument is currently ignored.
throwResultError :: ByteString -> PQ.Result -> PQ.ExecStatus -> IO a
throwResultError _ result status = do
    let field code = fromMaybe "" <$> PQ.resultErrorField result code
    msg    <- field PQ.DiagMessagePrimary
    detail <- field PQ.DiagMessageDetail
    hint   <- field PQ.DiagMessageHint
    state  <- field PQ.DiagSqlstate
    throwIO SqlError { sqlState       = state
                     , sqlExecStatus  = status
                     , sqlErrorMsg    = msg
                     , sqlErrorDetail = detail
                     , sqlErrorHint   = hint
                     }
-- | The 'SqlError' thrown by 'withConnection' when the handle holds a
-- null libpq connection.
disconnectedError :: SqlError
disconnectedError = fatalError message
  where message = "connection disconnected"
-- | Run an action with exclusive use of the underlying libpq connection.
-- The handle 'MVar' is held for the duration, via 'withMVar'.
--
-- Throws 'disconnectedError' if the handle is a null connection.
withConnection :: Connection -> (PQ.Connection -> IO a) -> IO a
withConnection Connection{..} action = withMVar connectionHandle useHandle
  where
    useHandle h
      | PQ.isNullConnection h = throwIO disconnectedError
      | otherwise             = action h
-- | Close a connection. This finishes the libpq connection and replaces
-- the handle with a fresh null connection, after which 'withConnection'
-- on this 'Connection' throws 'disconnectedError'.
--
-- If another thread is inside 'withConnection', this waits for it to
-- release the handle.
--
-- The 'finally' handler covers only 'PQ.finish', not the interruptible
-- 'takeMVar'. If an asynchronous exception arrives while blocked in
-- 'takeMVar', nothing was taken, so nothing may be put back. Putting back
-- anyway would fill an 'MVar' owned by another thread, and that thread
-- would then block forever when returning its handle.
close :: Connection -> IO ()
close Connection{..} =
    mask $ \restore -> do
        conn <- takeMVar connectionHandle
        restore (PQ.finish conn)
            `finally` (putMVar connectionHandle =<< PQ.newNullConnection)
-- | Build a 'Connection' around a null libpq connection, with an empty
-- type cache and a temp-name counter starting at 0.
newNullConnection :: IO Connection
newNullConnection =
    Connection <$> (newMVar =<< PQ.newNullConnection)
               <*> newMVar IntMap.empty
               <*> newIORef 0
-- | The row being decoded: its index and the result it comes from.
data Row = Row {
     row       :: {-# UNPACK #-} !PQ.Row
     -- ^ Row index within 'rowresult'.
   , rowresult :: !PQ.Result
     -- ^ The libpq result the row is read from.
   }
-- | Parser for one result row. It reads the current 'Row' from the
-- environment and threads a column index as state (presumably the next
-- column to read; confirm against the field parsers). It runs in
-- 'Conversion', so it can fail with errors and consult the 'Connection'.
newtype RowParser a = RP { unRP :: ReaderT Row (StateT PQ.Column Conversion) a }
   deriving ( Functor, Applicative, Alternative, Monad )
-- | Lift an 'IO' action into 'RowParser' through 'liftConversion'; the
-- action's result is treated as a success.
liftRowParser :: IO a -> RowParser a
liftRowParser io = RP (lift (lift (liftConversion io)))
-- | A computation that may consult the 'Connection'. It yields an 'Ok'
-- outcome: either a value ('Ok') or a list of errors ('Errors').
newtype Conversion a = Conversion { runConversion :: Connection -> IO (Ok a) }
-- | Run an 'IO' action inside 'Conversion', ignoring the connection and
-- wrapping its result as a success. Exceptions it throws are not caught.
liftConversion :: IO a -> Conversion a
liftConversion m = Conversion (const (fmap Ok m))
-- | Map over the outcome using 'Ok''s own 'Functor' instance.
instance Functor Conversion where
  fmap f (Conversion run) = Conversion (fmap (fmap f) . run)
-- | 'pure' ignores the connection. '<*>' runs the function computation
-- first and runs the argument computation only if that succeeded;
-- otherwise the function's errors are returned unchanged.
instance Applicative Conversion where
  pure a = Conversion (\_ -> return (pure a))
  Conversion runF <*> Conversion runA = Conversion $ \conn ->
    runF conn >>= \okf -> case okf of
      Errors errs -> return (Errors errs)
      Ok f        -> fmap f <$> runA conn
-- | 'empty' lifts 'Ok''s 'empty'. '<|>' runs the right computation only
-- if the left one failed, then combines both outcomes with 'Ok''s '<|>'.
instance Alternative Conversion where
  empty = Conversion (\_ -> return empty)
  Conversion runL <|> Conversion runR = Conversion $ \conn -> do
    left <- runL conn
    case left of
      Ok _     -> return left
      Errors _ -> fmap (left <|>) (runR conn)
-- | Sequencing stops at the first failed computation and returns its
-- errors; the continuation is not run in that case.
instance Monad Conversion where
#if !(MIN_VERSION_base(4,8,0))
  return = pure
#endif
  Conversion run >>= k = Conversion $ \conn -> do
    outcome <- run conn
    case outcome of
      Errors errs -> return (Errors errs)
      Ok a        -> runConversion (k a) conn
-- | Delegates to the 'Alternative' instance.
instance MonadPlus Conversion where
  mzero     = empty
  mplus a b = a <|> b
-- | Post-process the 'Ok' outcome of a conversion, whether it succeeded
-- or failed.
conversionMap :: (Ok a -> Ok b) -> Conversion a -> Conversion b
conversionMap f (Conversion run) = Conversion (fmap f . run)
-- | A conversion that always fails, with the given exception as its only
-- error.
conversionError :: Exception err => err -> Conversion a
conversionError err = Conversion (const (return (Errors [toException err])))
-- | Produce a name of the form @tempN@ by atomically incrementing this
-- connection's counter. With the counter starting at 0, the first name is
-- @temp1@.
newTempName :: Connection -> IO Query
newTempName Connection{..} = do
    next <- atomicModifyIORef' connectionTempNameCounter bump
    return $! Query (B8.pack ("temp" ++ show next))
  where
    bump current = let successor = current + 1 in (successor, successor)
-- | The 'IOError' raised when libpq has no socket file descriptor to wait
-- on. The argument names the calling function and becomes the location.
fdError :: ByteString -> IOError
fdError funcName = IOError
    { ioe_type        = ResourceVanished
    , ioe_description = "failed to fetch file descriptor"
    , ioe_location    = B8.unpack funcName
    , ioe_handle      = Nothing
    , ioe_errno       = Nothing
    , ioe_filename    = Nothing
    }
-- | An 'OtherError' 'IOError' with location @libpq@. The message bytes are
-- converted with 'B8.unpack', i.e. one character per byte.
libPQError :: ByteString -> IOError
libPQError desc = IOError
    { ioe_location    = "libpq"
    , ioe_description = B8.unpack desc
    , ioe_type        = OtherError
    , ioe_handle      = Nothing
    , ioe_errno       = Nothing
    , ioe_filename    = Nothing
    }
-- | Throw a 'libPQError' with the connection's current error message, or
-- with the given fallback text if libpq reports none.
throwLibPQError :: PQ.Connection -> ByteString -> IO a
throwLibPQError conn fallback =
    PQ.errorMessage conn >>= \mmsg ->
        throwIO $! libPQError (fromMaybe fallback mmsg)
-- | Throw (purely, via 'throw') a 'FormatError' for a query and its
-- parameters. Each parameter is flattened to bytes for the report:
-- 'Plain' builders are rendered, 'Escape'-style payloads are included
-- as-is, and 'Many' is concatenated.
fmtError :: String -> Query -> [Action] -> a
fmtError msg q xs =
    throw FormatError { fmtMessage = msg
                      , fmtQuery   = q
                      , fmtParams  = map flatten xs
                      }
  where
    flatten action = case action of
      Plain builder          -> toByteString builder
      Escape bytes           -> bytes
      EscapeByteA bytes      -> bytes
      EscapeIdentifier bytes -> bytes
      Many actions           -> B.concat (map flatten actions)
-- | 'fmtError' with the message supplied as bytes (e.g. a libpq error
-- message).
--
-- The message is decoded as UTF-8 when valid, and byte-for-byte otherwise.
-- 'TE.decodeUtf8' would instead throw on invalid input, replacing the
-- 'FormatError' being reported with an unrelated decoding exception.
fmtErrorBs :: Query -> [Action] -> ByteString -> a
fmtErrorBs q xs msg = fmtError (decodeMessage msg) q xs
  where
    decodeMessage bytes =
      either (const (B8.unpack bytes)) T.unpack (TE.decodeUtf8' bytes)
-- | Render the outcome of escaping a value. A success is passed to
-- 'inQuotes'. A failure becomes a 'FormatError', raised lazily via
-- 'fmtErrorBs' when the builder is evaluated.
quote :: Query -> [Action] -> Either ByteString ByteString -> Builder
quote q xs escaped = case escaped of
    Left err    -> fmtErrorBs q xs err
    Right value -> inQuotes (byteString value)
-- | Render a query parameter as a 'Builder', escaping through libpq on
-- this connection where needed:
--
-- * 'Plain' is emitted unchanged;
-- * 'Escape' and 'EscapeByteA' are escaped and then passed to 'quote';
-- * 'EscapeIdentifier' is escaped as an identifier (no 'quote' here);
-- * 'Many' renders each element in order and concatenates the results.
--
-- Escaping failures are reported via 'fmtErrorBs' with the query and the
-- full parameter list @xs@.
buildAction :: Connection
            -> Query
            -> [Action]
            -> Action
            -> IO Builder
buildAction conn q xs action = case action of
    Plain b            -> pure b
    Escape s           -> quote q xs <$> escapeStringConn conn s
    EscapeByteA s      -> quote q xs <$> escapeByteaConn conn s
    EscapeIdentifier s ->
        either (fmtErrorBs q xs) byteString <$> escapeIdentifier conn s
    Many ys            -> mconcat <$> mapM (buildAction conn q xs) ys
-- | Convert the 'Maybe' result of a libpq call into an 'Either'. 'Just'
-- is a success; 'Nothing' becomes the connection's current error message,
-- or an empty string if libpq reports none.
checkError :: PQ.Connection -> Maybe a -> IO (Either ByteString a)
checkError c outcome = case outcome of
    Just x  -> return (Right x)
    Nothing -> Left . fromMaybe "" <$> PQ.errorMessage c
-- | Adapt a libpq escaping function to a 'Connection'. It runs while the
-- connection is held (via 'withConnection') and reports failure as 'Left'
-- carrying libpq's error message.
escapeWrap :: (PQ.Connection -> ByteString -> IO (Maybe ByteString))
           -> Connection
           -> ByteString
           -> IO (Either ByteString ByteString)
escapeWrap f conn s = withConnection conn (\c -> checkError c =<< f c s)
-- | Escape a string value with libpq's 'PQ.escapeStringConn'.
escapeStringConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeStringConn conn = escapeWrap PQ.escapeStringConn conn
-- | Escape an identifier with libpq's 'PQ.escapeIdentifier'.
escapeIdentifier :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeIdentifier conn = escapeWrap PQ.escapeIdentifier conn
-- | Escape binary data with libpq's 'PQ.escapeByteaConn'.
escapeByteaConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeByteaConn conn = escapeWrap PQ.escapeByteaConn conn