{-|
Module      :  Database.Persist.Migration.Utils.Sql
Maintainer  :  Brandon Chinn <brandonchinn178@gmail.com>
Stability   :  experimental
Portability :  portable

Defines helper functions for writing SQL queries.
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Database.Persist.Migration.Utils.Sql
  ( commas
  , uncommas
  , uncommas'
  , quote
  , MigrateSql(..)
  , executeSql
  , pureSql
  , mapSql
  , concatSql
  ) where

import Control.Monad.IO.Class (MonadIO(..))
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import Data.Text (Text)
import qualified Data.Text as Text
import Database.Persist.Sql (PersistValue(..), SqlPersistT)
import qualified Database.Persist.Sql as Persist

-- | Split the given line by commas, ignoring commas within parentheses.
--
-- Items are not trimmed, and the result always has at least one element
-- (@commas ""@ is @[""]@; a trailing comma yields a trailing empty item).
-- An unmatched closing parenthesis never drives the nesting depth below zero.
--
-- > commas "a,b,c" == ["a", "b", "c"]
-- > commas "a,b,c (d,e),z" == ["a", "b", "c (d,e)", "z"]
-- > commas "a,b,c (d,e,(f,g)),z" == ["a", "b", "c (d,e,(f,g))", "z"]
commas :: Text -> [Text]
commas t = go (Text.unpack t) [] [] 0
  where
    -- The current item (@buffer@) and the finished items (@result@) are both
    -- accumulated in reverse so each step is O(1); appending with '++' on
    -- every character made the whole split quadratic in the input length.
    go :: String -> String -> [Text] -> Int -> [Text]
    go src buffer result level =
      let result' = Text.pack (reverse buffer) : result
      in case src of
        [] -> reverse result'
        ',' : xs | level == 0 -> go xs [] result' level
        '(' : xs -> go xs ('(' : buffer) result (level + 1)
        ')' : xs -> go xs (')' : buffer) result (max 0 (level - 1))
        x : xs -> go xs (x : buffer) result level

-- | Join the given items into one Text, putting a single comma (and no
-- whitespace) between consecutive items.
--
-- > uncommas ["a", "b"] == "a,b"
uncommas :: [Text] -> Text
uncommas items = Text.intercalate (Text.singleton ',') items

-- | Quote every item with 'quote', then join the quoted items with
-- 'uncommas'.
uncommas' :: [Text] -> Text
uncommas' items = uncommas [quote item | item <- items]

-- | Wrap the given Text in double quotes (SQL quoted-identifier syntax).
--
-- Any double quote already inside the text is doubled (@"@ becomes @""@),
-- the standard SQL escape inside a quoted identifier. Without this, a name
-- containing @"@ would close the identifier early and yield malformed SQL.
--
-- > quote "name" == "\"name\""
-- > quote "a\"b" == "\"a\"\"b\""
quote :: Text -> Text
quote t = Text.concat [dquote, Text.replace dquote escaped t, dquote]
  where
    dquote = Text.singleton '"'
    -- Non-empty needle above, so 'Text.replace' cannot fail.
    escaped = Text.pack "\"\""

-- | A SQL query, possibly containing placeholders, paired with the values
-- that fill those placeholders.
data MigrateSql = MigrateSql
  { sqlText :: Text
    -- ^ The SQL text, which may contain placeholders.
  , sqlVals :: [PersistValue]
    -- ^ The values substituted for the placeholders in 'sqlText'.
  }
  deriving Show

-- | Execute a query by handing its SQL text and placeholder values to
-- 'Persist.rawExecute'.
executeSql :: MonadIO m => MigrateSql -> SqlPersistT m ()
executeSql query = Persist.rawExecute (sqlText query) (sqlVals query)

-- | Wrap plain SQL text as a 'MigrateSql' that carries no placeholder values.
pureSql :: Text -> MigrateSql
pureSql sql = MigrateSql { sqlText = sql, sqlVals = [] }

-- | Transform a query's SQL text with the given function, leaving its
-- placeholder values untouched.
mapSql :: (Text -> Text) -> MigrateSql -> MigrateSql
mapSql f (MigrateSql text vals) = MigrateSql (f text) vals

-- | Combine several queries into one. The SQL texts are merged by the given
-- combining function, and the placeholder values are concatenated in the
-- same order as the input queries.
concatSql :: ([Text] -> Text) -> [MigrateSql] -> MigrateSql
concatSql f queries = MigrateSql (f texts) vals
  where
    texts = map sqlText queries
    vals = concatMap sqlVals queries