{-# LANGUAGE BangPatterns #-}

{-|
Module      : Database.MySQL.Protocol.Escape
Description : Pure haskell mysql escape
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This module provides escaping for 'ByteString' and 'Text' values that are to be
embedded in MySQL string literals.

Reference: <http://dev.mysql.com/doc/refman/5.7/en/string-literals.html>

> Escape sequence   Character represented by the sequence
> \0                An ASCII NUL (X'00') character
> \'                A single quote (') character
> \"                A double quote (") character
> \b                A backspace character
> \n                A newline (linefeed) character
> \r                A carriage return character
> \t                A tab character
> \Z                ASCII 26 (Control+Z); see the MySQL manual
> \\                A backslash (\) character
> \%                A % character; see the note below
> \_                A _ character; see the note below

The @\\%@ and @\\_@ sequences are used to search for literal instances of @%@ and @_@ in pattern-matching contexts where they would otherwise be interpreted as wildcard characters. Because that only matters inside patterns, @%@ and @_@ are not escaped automatically here.

-}

module Database.MySQL.Protocol.Escape where

import           Data.ByteString          (ByteString)
import qualified Data.ByteString.Internal as B
import           Data.Text                (Text)
import qualified Data.Text.Array          as TA
import qualified Data.Text.Internal       as T
import           Data.Word
import           Foreign.ForeignPtr       (withForeignPtr)
import           Foreign.Ptr              (Ptr, minusPtr, plusPtr)
import           Foreign.Storable         (peek, poke, pokeByteOff)
import           GHC.IO                   (unsafeDupablePerformIO)

-- | Escape a 'Text' value for inclusion in a MySQL string literal.
--
-- NUL, @'@, @"@, backspace, newline, carriage return, tab, Control+Z
-- (ASCII 26) and backslash are each replaced by a backslash followed by the
-- corresponding escape character (@0@, @'@, @"@, @b@, @n@, @r@, @t@, @Z@,
-- @\\@); every other character is copied unchanged. @%@ and @_@ are
-- deliberately left alone.
--
-- NUL becomes the two-character sequence backslash + ASCII @'0'@, matching
-- the MySQL escape table, rather than a backslash followed by a raw NUL.
--
-- The input is processed one array code unit at a time. Every character
-- that needs escaping is ASCII, and no code unit that is part of a
-- multi-unit encoding (a UTF-16 surrogate, or a UTF-8 byte >= 0x80) can
-- equal one of them, so non-ASCII characters pass through untouched. No
-- concrete code-unit type is fixed here, so the same code type-checks
-- against either array representation used by the @text@ package.
escapeText :: Text -> Text
escapeText (T.Text arr off len)
    | len <= 0  = T.empty
    | otherwise =
        let (arr', len') = TA.run2 $ do
                -- Worst case every unit expands to two, hence len * 2.
                marr <- TA.new (len * 2)
                loop arr (off + len) marr off 0
        in T.Text arr' 0 len'
  where
    -- Copy units [i, end) of src into dst starting at index j, escaping as
    -- needed; returns the destination buffer and the number of units written.
    loop src end dst !i !j
        | i >= end  = return (dst, j)
        | otherwise = do
            let c = TA.unsafeIndex src i
            case escapeCode c of
                Just e -> do
                    TA.unsafeWrite dst j 92          -- the backslash
                    TA.unsafeWrite dst (j + 1) e
                    loop src end dst (i + 1) (j + 2)
                Nothing -> do
                    TA.unsafeWrite dst j c
                    loop src end dst (i + 1) (j + 1)

    -- The unit to write after the backslash, or 'Nothing' when the unit is
    -- copied verbatim.
    escapeCode :: (Eq a, Num a) => a -> Maybe a
    escapeCode c = case c of
        0  -> Just 48   -- \0 for NUL (ASCII '0', never a raw NUL)
        39 -> Just 39   -- \' for single quote
        34 -> Just 34   -- \" for double quote
        8  -> Just 98   -- \b for backspace
        10 -> Just 110  -- \n for newline
        13 -> Just 114  -- \r for carriage return
        9  -> Just 116  -- \t for tab
        26 -> Just 90   -- \Z for Control+Z
        92 -> Just 92   -- \\ for backslash
        _  -> Nothing

-- | Escape a 'ByteString' for inclusion in a MySQL string literal.
--
-- NUL, @'@, @"@, backspace, newline, carriage return, tab, Control+Z
-- (ASCII 26) and backslash are each replaced by a backslash followed by the
-- corresponding escape character (@0@, @'@, @"@, @b@, @n@, @r@, @t@, @Z@,
-- @\\@); every other byte is copied unchanged. @%@ and @_@ are deliberately
-- left alone.
--
-- NUL becomes the two-byte sequence backslash + ASCII @'0'@, matching the
-- MySQL escape table, rather than a backslash followed by a raw NUL byte.
--
-- NOTE(review): escaping is byte-wise, which assumes ASCII bytes never occur
-- inside multibyte characters. That holds for UTF-8 but not for charsets such
-- as GBK, Big5 or SJIS; confirm the connection charset.
escapeBytes :: ByteString -> ByteString
escapeBytes (B.PS fp s len) =
    -- unsafeDupablePerformIO is safe here: the input buffer is only read and
    -- the output buffer is freshly allocated by createUptoN, so the whole
    -- computation is a pure function of the input bytes.
    unsafeDupablePerformIO $ withForeignPtr fp $ \ src ->
        -- Worst case every byte expands to two, hence len * 2.
        B.createUptoN (len * 2) $ \ dst -> do
            let start = src `plusPtr` s
            dstEnd <- loop start (start `plusPtr` len) dst
            return (dstEnd `minusPtr` dst)
  where
    -- Copy bytes [a, aend) to b, escaping as needed; returns the pointer one
    -- past the last byte written.
    loop :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> IO (Ptr Word8)
    loop !a aend !b
        | a >= aend = return b
        | otherwise = do
            c <- peek a
            case escapeCode c of
                Just e -> do
                    poke b 92                  -- the backslash
                    pokeByteOff b 1 e
                    loop (a `plusPtr` 1) aend (b `plusPtr` 2)
                Nothing -> do
                    poke b c
                    loop (a `plusPtr` 1) aend (b `plusPtr` 1)

    -- The byte to write after the backslash, or 'Nothing' when the byte is
    -- copied verbatim.
    escapeCode :: Word8 -> Maybe Word8
    escapeCode c = case c of
        0  -> Just 48   -- \0 for NUL (ASCII '0', never a raw NUL)
        39 -> Just 39   -- \' for single quote
        34 -> Just 34   -- \" for double quote
        8  -> Just 98   -- \b for backspace
        10 -> Just 110  -- \n for newline
        13 -> Just 114  -- \r for carriage return
        9  -> Just 116  -- \t for tab
        26 -> Just 90   -- \Z for Control+Z
        92 -> Just 92   -- \\ for backslash
        _  -> Nothing