module Database.PostgreSQL.PQTypes.SQL
  ( SQL
  , mkSQL
  , sqlParam
  , (<?>)
  , isSqlEmpty
  ) where

import Control.Concurrent.MVar
import Data.ByteString.Char8 qualified as BS
import Data.ByteString.Unsafe qualified as BS
import Data.Foldable qualified as F
import Data.Monoid
import Data.Semigroup qualified as SG
import Data.Sequence qualified as S
import Data.String
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Foreign.Marshal.Alloc
import TextShow

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Format
import Database.PostgreSQL.PQTypes.Internal.C.Put
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.ToSQL

-- | One piece of an 'SQL' fragment: either literal query text or an
-- embedded parameter value.
data SqlChunk where
  -- | Literal query text, emitted into the query unchanged.
  SqlString :: !T.Text -> SqlChunk
  -- | A parameter value. It is rendered as a numbered placeholder in the
  -- query text and shown as @<value>@ by the 'Show' instance.
  SqlParam :: (Show t, ToSQL t) => !t -> SqlChunk

-- | Primary SQL type that supports efficient concatenation and a variable
-- number of parameters.
--
-- Internally an 'SQL' is an ordered sequence of chunks, each of which is
-- either literal text or an embedded parameter value.
newtype SQL
  = SQL (S.Seq SqlChunk)

-- | Flatten an 'SQL' into its list of chunks, preserving their order.
unSQL :: SQL -> [SqlChunk]
unSQL (SQL chunkSeq) = F.foldr (:) [] chunkSeq

----------------------------------------

-- | Construct 'SQL' from 'String'.
--
-- The string becomes a single literal text chunk, exactly as 'mkSQL' does
-- for 'T.Text'.
instance IsString SQL where
  fromString str = mkSQL (T.pack str)

-- | Render an 'SQL' into a single query string and hand it, together with
-- the filled-in parameter set, to the supplied continuation.
--
-- Literal text chunks are copied verbatim. Every parameter chunk is put
-- into the parameter set and replaced by a placeholder @$1@, @$2@, ...,
-- numbered in the order the parameters appear.
instance IsSQL SQL where
  withSQL sql pa@(ParamAllocator allocParam) execute =
    alloca $ \pgErr ->
      allocParam $ \pgParam -> do
        -- Placeholder counter: the first parameter becomes "$1".
        counter <- newMVar (1 :: Int)
        -- Text chunks pass through unchanged; a parameter chunk is stored
        -- in the parameter set and replaced by its "$n" placeholder.
        let render (SqlString txt) = pure txt
            render (SqlParam (val :: t)) =
              toSQL val pa $ \base ->
                -- NOTE(review): unsafeUseAsCString needs a NUL-terminated
                -- ByteString; presumably pqFormat0 guarantees this — confirm.
                BS.unsafeUseAsCString (pqFormat0 @t) $ \fmt -> do
                  rc <- c_PQputf1 pgParam pgErr fmt base
                  verifyPQTRes pgErr "withSQL (SQL)" rc
                  -- Hand out the current number and store its successor,
                  -- forced so the counter does not accumulate thunks.
                  modifyMVar counter $ \n ->
                    let next = n + 1
                     in next `seq` pure (next, "$" <> showt n)
        -- Chunks are rendered left to right, so placeholder numbers follow
        -- the textual order of the parameters.
        pieces <- traverse render (unSQL sql)
        BS.useAsCString (T.encodeUtf8 (T.concat pieces)) (execute pgParam)

-- | Concatenation joins the two chunk sequences end to end, keeping the
-- chunks of the left operand first.
instance SG.Semigroup SQL where
  (<>) (SQL lhs) (SQL rhs) = SQL (lhs S.>< rhs)

-- | The identity is the empty chunk sequence.
instance Monoid SQL where
  -- Use the empty sequence rather than a single empty text chunk, so that
  -- @x <> mempty@ and @mempty <> x@ are structurally equal to @x@. Folds
  -- ending in 'mempty' (such as the default 'mconcat') then do not
  -- accumulate a useless empty chunk. The rendered query is the same either
  -- way, because an empty text chunk contributes nothing to it.
  mempty = SQL S.empty
  mappend = (SG.<>)

-- | Debug rendering of an 'SQL'.
--
-- Literal text is shown verbatim and every parameter as @<value>@ (using
-- the parameter's own 'Show'). The concatenation is shown as a string
-- literal after the word @SQL@, e.g. @SQL "SELECT <1>"@.
--
-- Parentheses are added when the context precedence exceeds function
-- application (> 10), following the convention of derived 'Show'
-- instances. A nested value therefore renders as @Just (SQL "...")@ rather
-- than @Just SQL "..."@. Output at top level is unaffected.
instance Show SQL where
  showsPrec d sql =
    showParen (d > 10) $
      showString "SQL " . showsPrec 11 (concatMap conv (unSQL sql))
    where
      conv :: SqlChunk -> String
      conv (SqlString s) = T.unpack s
      conv (SqlParam v) = "<" ++ show v ++ ">"

----------------------------------------

-- | Convert a 'Text' SQL string to the 'SQL' type.
--
-- The text is stored verbatim as a single chunk. No escaping or parameter
-- extraction is performed.
mkSQL :: T.Text -> SQL
mkSQL txt = SQL (S.singleton (SqlString txt))

-- | Embed parameter value inside 'SQL'.
--
-- The value is not spliced into the query text. When the query is run it
-- is passed as a separate parameter and referenced by a numbered
-- placeholder.
sqlParam :: (Show t, ToSQL t) => t -> SQL
sqlParam val = SQL (S.singleton (SqlParam val))

-- | Embed parameter value inside existing 'SQL'. Example:
--
-- > f :: Int32 -> String -> SQL
-- > f idx name = "SELECT foo FROM bar WHERE id =" <?> idx <+> "AND name =" <?> name
--
-- The parameter is attached to the preceding fragment with '<+>'.
(<?>) :: (Show t, ToSQL t) => SQL -> t -> SQL
(<?>) prefix val = prefix <+> sqlParam val

infixr 7 <?>

----------------------------------------

-- | Test whether an 'SQL' is empty.
--
-- An 'SQL' counts as empty when every chunk is empty text; this is
-- vacuously true when there are no chunks at all. Any embedded parameter
-- makes it non-empty, whatever the parameter's value.
isSqlEmpty :: SQL -> Bool
isSqlEmpty (SQL chunkSeq) = F.all isBlankChunk chunkSeq
  where
    isBlankChunk :: SqlChunk -> Bool
    isBlankChunk chunk = case chunk of
      SqlString txt -> T.null txt
      SqlParam _ -> False