#include "MachDeps.h"
module Database.PostgreSQL.Escape (
fmtSql, quoteIdent
, buildSql, buildSqlFromActions
, buildAction, buildLiteral, buildByteA, buildIdent
) where
import Blaze.ByteString.Builder
import Blaze.ByteString.Builder.Char8 (fromChar)
import Blaze.ByteString.Builder.Internal
import qualified Data.ByteString as S
import qualified Data.ByteString.Internal as S
import qualified Data.ByteString.Unsafe as S
import Data.Monoid
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.ToField
import Database.PostgreSQL.Simple.ToRow
import Database.PostgreSQL.Simple.Types
import Foreign.Marshal.Alloc (mallocBytes)
import Foreign.Storable (pokeByteOff)
import Foreign.Ptr
import GHC.Prim (Addr#, and#, geAddr#, geWord#, Int#, int2Word#
, minusAddr#, ord# , plusAddr#, readWord8OffAddr#
, State# , uncheckedShiftRL#, word2Int#, writeWord8OffAddr#
, Word#)
import GHC.Ptr (Ptr(Ptr))
import GHC.Types (Char(C#), Int(I#), IO(IO))
import GHC.Word (Word8(W8#))
import System.IO.Unsafe (unsafeDupablePerformIO)
-- cmpres turns the result of a primitive comparison (geAddr#, geWord#, and
-- so on) into a Bool. From GHC 7.8 onward (version macro 707 and up) those
-- primops return Int#, where 0# means false; older compilers return Bool.
#if __GLASGOW_HASKELL__ >= 707
cmpres :: Int# -> Bool
cmpres 0# = False
cmpres _ = True
#else /* __GLASGOW_HASKELL__ < 707 */
cmpres :: Bool -> Bool
cmpres b = b
-- Call sites write the argument in parentheses directly after the name, so
-- on older compilers the macro below erases the call and leaves the Bool.
#define cmpres(b) b
#endif /* __GLASGOW_HASKELL__ < 707 */
-- | The code of a character as a byte, intended for ASCII literals.
--
-- Going through fromIntegral narrows the value modulo 256, so the result is
-- always a valid Word8. Wrapping the raw ord# value with W8# skips that
-- narrowing (a code point above 255 would give a Word8 holding a value
-- above 255), and it also depends on the GHC-version-specific W8# field type.
c2b :: Char -> Word8
c2b = fromIntegral . fromEnum
-- | Unboxed code of a character as a raw Word#, used for byte literals in
-- primop code. No narrowing is applied, so callers pass ASCII characters.
c2b# :: Char -> Word#
c2b# ch = case ch of
  C# c -> int2Word# (ord# c)
-- | Index of the first byte for which the predicate (applied to the raw
-- Word# of the byte) holds, or Nothing when no byte matches.
--
-- Bytes are tested left to right and the scan stops at the first match.
-- This delegates to S.findIndex instead of a hand-written pointer loop run
-- under inlinePerformIO, which bytestring deprecated and later removed.
fastFindIndex :: (Word# -> Bool) -> S.ByteString -> Maybe Int
fastFindIndex test = S.findIndex (\(W8# w) -> test w)
-- | Split a byte string just before the first byte satisfying the predicate.
-- If no byte matches, the whole input is the first component and the second
-- component is empty.
fastBreak :: (Word# -> Bool) -> S.ByteString -> (S.ByteString, S.ByteString)
fastBreak test bs = maybe (bs, S.empty) splitAtMatch (fastFindIndex test bs)
  where
    splitAtMatch n = (S.unsafeTake n bs, S.unsafeDrop n bs)
-- | Generic quoting helper. Emits @start@, then the input with every byte
-- satisfying @escPred@ replaced by the builder from @escFn@ (runs of other
-- bytes are copied verbatim), then @end@.
quoter :: S.ByteString -> S.ByteString -> (Word# -> Bool)
       -> (Word8 -> Builder) -> S.ByteString -> Builder
quoter start end escPred escFn bs0 =
    copyByteString start <> go bs0 <> copyByteString end
  where
    go bs
      | S.null rest = fromByteString plain
      | otherwise   = mconcat [ fromByteString plain
                              , escFn (S.unsafeHead rest)
                              , go (S.unsafeTail rest)
                              ]
      where (plain, rest) = fastBreak escPred bs
-- | Quote an identifier in the PostgreSQL U& (Unicode escape) form.
--
-- Double quote bytes are doubled, backslash bytes are doubled, and each
-- question mark is written as the escape for code point 003f. All other
-- bytes are copied unchanged.
uBuildIdent :: S.ByteString -> Builder
uBuildIdent ident = quoter " U&\"" "\"" needsEscape escapeByte ident
  where
    needsEscape w = case w of
      34## -> True
      63## -> True
      92## -> True
      _    -> False
    escapeByte b = copyByteString (escapeSeq b)
    escapeSeq b
      | b == c2b '"'  = "\"\""
      | b == c2b '?'  = "\\003f"
      | b == c2b '\\' = "\\\\"
      | otherwise     = error "uquoteIdent"
-- | Render a byte string as a double-quoted SQL identifier.
--
-- Embedded double quotes are doubled. An identifier that contains a question
-- mark is delegated to uBuildIdent, which writes the question mark as a
-- Unicode escape so that no literal question mark appears in the output.
-- A NUL byte anywhere in the input is rejected with error.
buildIdent :: S.ByteString -> Builder
buildIdent ident
  -- The NUL check covers the whole identifier, independently of the
  -- question-mark test, so a NUL located after a question mark is also
  -- rejected instead of slipping into the U& form.
  | S.elem 0 ident = error "quoteIdent: illegal NUL character"
  | S.elem (c2b '?') ident = uBuildIdent ident
  | otherwise = quoter "\"" "\"" isDQuote (const $ copyByteString "\"\"") ident
  where isDQuote 34## = True
        isDQuote _    = False
-- | Quote an identifier and materialise the result as a strict byte string.
quoteIdent :: S.ByteString -> S.ByteString
quoteIdent ident = toByteString (buildIdent ident)
-- | Pointer to the 16 ASCII digits 0-9 followed by a-f, indexed by nibble
-- value. This is the lookup table read by uncheckedWriteNibbles#.
--
-- The table is a primitive string literal, i.e. static data emitted by the
-- compiler, so nothing is allocated or needs freeing at runtime. A top-level
-- unsafeDupablePerformIO allocation would need a NOINLINE pragma to avoid
-- being inlined or evaluated more than once, each copy leaking its buffer.
hexNibblesPtr :: Ptr Word8
hexNibblesPtr = Ptr "0123456789abcdef"#
-- | Write the two lowercase hex digits of the byte value @w@ to @p@ (high
-- nibble) and @p+1@ (low nibble), looking each digit up in hexNibblesPtr.
--
-- Unchecked: nothing verifies that @w@ is below 256 (a larger value would
-- index past the 16-entry table) or that @p@ has two writable bytes; the
-- caller must guarantee both.
uncheckedWriteNibbles# :: Addr# -> Word# -> State# d -> State# d
uncheckedWriteNibbles# p w rw0 =
    -- Split w into its high (h) and low (l) nibble, as Int# table indices.
    case (# word2Int# (w `uncheckedShiftRL#` 4# )
          , word2Int# (w `and#` 0xf## ) #) of { (# h, l #) ->
    -- Look up and store the high digit, then the low digit, threading the
    -- state token through each read and write in order.
    case readWord8OffAddr# nibbles h rw0 of { (# rw1, hascii #) ->
    case writeWord8OffAddr# p 0# hascii rw1 of { rw2 ->
    case readWord8OffAddr# nibbles l rw2 of { (# rw3, lascii #) ->
    writeWord8OffAddr# p 1# lascii rw3 }}}}
  where
    -- Bang pattern: force the table pointer and unwrap its Addr#.
    !(Ptr nibbles) = hexNibblesPtr
-- | Builder for the four-byte hex escape of one byte: a backslash, the
-- letter x, then the two lowercase hex digits of the byte (for example byte
-- 63 becomes backslash, x, 3, f). buildLiteral uses it for the bytes it
-- escapes numerically.
hexCharEscBuilder :: Word8 -> Builder
hexCharEscBuilder (W8# w) = fromWrite $ exactWrite 4 $ \(Ptr p) -> IO $ \rw0 ->
    -- The state token threads the writes from the inside out: offset 0 is
    -- written first, then offset 1, then the two digits at offsets 2 and 3.
    (# uncheckedWriteNibbles# (p `plusAddr#` 2#) w
         (writeWord8OffAddr# p 1# (c2b# 'x')
           (writeWord8OffAddr# p 0# (c2b# '\\') rw0))
     , () #)
-- | Render a byte string as a PostgreSQL escape-string literal (E prefix).
--
-- Single-quote bytes and backslash bytes are each doubled. Question marks
-- and every byte of value 128 or above are written as hex escapes via
-- hexCharEscBuilder. All other bytes are copied verbatim.
--
-- NOTE(review): a NUL byte is not treated specially here and is copied
-- verbatim, whereas the identifier quoter rejects NUL. Confirm intended.
buildLiteral :: S.ByteString -> Builder
buildLiteral = quoter " E'" "'" needsEscape escapeByte
  where
    needsEscape w = case w of
      39## -> True
      63## -> True
      92## -> True
      _    -> cmpres(w `geWord#` 128##)
    escapeByte b
      | b == c2b '\\' = copyByteString "\\\\"
      | b == c2b '\'' = copyByteString "''"
      | otherwise     = hexCharEscBuilder b
-- | Read the byte at @src@ and write its two lowercase hex digits to @dst@
-- and @dst+1@ via uncheckedWriteNibbles#. @dst@ must have two writable bytes.
copyByteToNibbles :: Addr# -> Addr# -> IO ()
copyByteToNibbles src dst = IO $ \rw0 ->
    case readWord8OffAddr# src 0# rw0 of
      (# rw1, w #) -> (# uncheckedWriteNibbles# dst w rw1, () #)
-- | Render a byte string as a hex-format bytea literal: an E-prefixed quoted
-- string whose content starts with an escaped backslash and the letter x,
-- followed by two lowercase hex digits per input byte.
--
-- The input is read through raw pointers from S.unsafeUseAsCStringLen, and
-- such a pointer is only valid while that call is running. So when the
-- output buffer fills up, the continuation records the byte offset reached
-- and re-enters S.unsafeUseAsCStringLen on resumption. Holding on to a
-- pointer that could outlive the byte string would be a use-after-free.
buildByteA :: S.ByteString -> Builder
buildByteA bs = equote $ fromBuildStepCont $ \cont ->
    let -- Encode the input from byte offset off onwards into the range.
        step (I# off) (BufRange (Ptr bb0) (Ptr be0)) =
          S.unsafeUseAsCStringLen bs $ \(Ptr inptr0, I# inlen0) -> do
            let ine = plusAddr# inptr0 inlen0
                fill oute inp outp
                  | cmpres(inp `geAddr#` ine) =
                      cont (BufRange (Ptr outp) (Ptr oute))
                  | cmpres(plusAddr# outp 2# `geAddr#` oute) = return $
                      bufferFull (2 * I# (ine `minusAddr#` inp) + 1) (Ptr outp)
                        (step (I# (inp `minusAddr#` inptr0)))
                  | otherwise = do
                      copyByteToNibbles inp outp
                      fill oute (inp `plusAddr#` 1#) (outp `plusAddr#` 2#)
            fill be0 (inptr0 `plusAddr#` off) bb0
    in step 0
  where
    equote b = mconcat [fromByteString " E'\\\\x", b, fromChar '\'']
-- | Translate one postgresql-simple Action into SQL text. Plain fragments
-- are copied as-is, Escape values become string literals, EscapeByteA values
-- become hex bytea literals, identifiers are quoted, and Many concatenates
-- the renderings of its nested actions in order.
buildAction :: Action -> Builder
buildAction action = case action of
  Plain b             -> b
  Escape bs           -> buildLiteral bs
  EscapeByteA bs      -> buildByteA bs
  EscapeIdentifier bs -> buildIdent bs
  Many actions        -> mconcat [buildAction a | a <- actions]
-- | Substitute rendered actions for the question-mark placeholders of a
-- query template, left to right.
--
-- The template is split on every question-mark byte. The number of actions
-- must equal the number of placeholders; otherwise error is raised with a
-- message naming the template.
buildSqlFromActions :: Query -> [Action] -> Builder
buildSqlFromActions (Query template) actions =
    interleave (splitTemplate template) (map buildAction actions)
  where
    -- Alternate template fragments and parameters. Splitting on k question
    -- marks yields k + 1 fragments, so exactly one fragment must remain
    -- when the parameters run out.
    interleave (t:ts) (p:ps) = t <> p <> interleave ts ps
    interleave [t] [] = t
    interleave _ _ =
      error $ "buildSql: wrong number of parameters for " ++ show template
    -- Split on question marks; an empty remainder still yields a final
    -- (possibly empty) fragment. S.break with an equality predicate replaces
    -- the deprecated S.breakByte.
    splitTemplate s = case S.break (== c2b '?') s of
      (h, t) | S.null t  -> [fromByteString h]
             | otherwise -> fromByteString h : splitTemplate (S.unsafeTail t)
-- | Render a query template with its parameters, converting the parameters
-- to actions through their ToRow instance.
buildSql :: (ToRow p) => Query -> p -> Builder
buildSql q = buildSqlFromActions q . toRow
-- | Like buildSql, but returns the finished SQL text wrapped as a Query.
fmtSql :: (ToRow p) => Query -> p -> Query
fmtSql q = Query . toByteString . buildSql q