module PostgREST.QueryBuilder (
addRelations
, addJoinConditions
, callProc
, createReadStatement
, createWriteStatement
, inTransaction
, operators
, pgFmtIdent
, pgFmtLit
, requestToQuery
, requestToCountQuery
, sourceCTEName
, unquoted
, ResultsWithCount
, Isolation(..)
) where
import qualified Hasql.Query as H
import qualified Hasql.Session as H
import qualified Hasql.Encoders as HE
import qualified Hasql.Decoders as HD
import qualified Data.Aeson as JSON
import Data.Int (Int64)
import PostgREST.RangeQuery (NonnegRange, rangeLimit, rangeOffset)
import Control.Error (note, fromMaybe, mapMaybe)
import Data.Functor.Contravariant (contramap)
import qualified Data.HashMap.Strict as HM
import Data.List (find, (\\))
import Data.Monoid ((<>))
import Data.Text (Text, intercalate, unwords, replace, isInfixOf, toLower, split)
import qualified Data.Text as T (map, takeWhile)
import Data.String.Conversions (cs)
import Control.Applicative ((<|>))
import Control.Monad (join)
import Data.Tree (Tree(..))
import qualified Data.Vector as V
import PostgREST.Types
import qualified Data.Map as M
import Text.InterpolatedString.Perl6 (qc)
import Text.Regex.TDFA ((=~))
import qualified Data.ByteString.Char8 as BS
import Data.Scientific ( FPFormat (..)
, formatScientific
, isInteger
)
import Prelude hiding (unwords)
import PostgREST.ApiRequest (PreferRepresentation (..))
-- | Row shape produced by the statements built in this module, in column
-- order: @(total_result_set, page_total, header, body)@. The first column
-- is nullable; the last two are raw bytes.
type ResultsWithCount = (Maybe Int64, Int64, BS.ByteString, BS.ByteString)
-- | Row decoder shared by 'decodeStandard' and 'decodeStandardMay': a
-- nullable @int8@, an @int8@ and two @bytea@ columns, packed into a
-- 'ResultsWithCount' tuple in that order.
standardRow :: HD.Row ResultsWithCount
standardRow =
  (,,,) <$> HD.nullableValue HD.int8
        <*> HD.value HD.int8
        <*> HD.value HD.bytea
        <*> HD.value HD.bytea

-- | Decode a result that is expected to consist of a single standard row.
decodeStandard :: HD.Result ResultsWithCount
decodeStandard = HD.singleRow standardRow

-- | Decode a result made of at most one standard row; yields 'Nothing'
-- when the result has no row.
decodeStandardMay :: HD.Result (Maybe ResultsWithCount)
decodeStandardMay = HD.maybeRow standardRow
-- | Parameter encoder that sends a 'UniformObjects' payload as one JSON
-- parameter: every object is wrapped with 'JSON.Object' and the whole
-- vector is wrapped in a 'JSON.Array'.
encodeUniformObjs :: HE.Params UniformObjects
encodeUniformObjs = contramap toJsonArray (HE.value HE.json)
  where
    toJsonArray = JSON.Array . V.map JSON.Object . unUniformObjects
-- | Build the statement used for read requests.
--
-- The select query is placed in a CTE named 'sourceCTEName'; the outer
-- query applies the LIMIT/OFFSET derived from @range@ and returns one row
-- of @(total_result_set, page_total, header, body)@.
--
-- * @countTotal@ – when set, @countQuery@ is embedded as a sub-select for
--   the total; otherwise the total column is @null@.
-- * @asCsv@ / @isSingle@ – choose the body aggregation (CSV wins over
--   single-object JSON, which wins over a JSON array).
createReadStatement :: SqlQuery -> SqlQuery -> NonnegRange -> Bool -> Bool -> Bool ->
                       H.Query () ResultsWithCount
createReadStatement selectQuery countQuery range isSingle countTotal asCsv =
  H.statement sql HE.unit decodeStandard True
  where
    sql = [qc|
WITH {sourceCTEName} AS ({selectQuery}) SELECT {cols}
FROM ( SELECT * FROM {sourceCTEName} {limitF range}) t |]
    cols = intercalate ", "
      [ totalF <> " AS total_result_set"
      , "pg_catalog.count(t) AS page_total"
      , "'' AS header"
      , bodyF <> " AS body"
      ]
    -- Either the parenthesised count sub-select or a literal null.
    totalF
      | countTotal = "(" <> countQuery <> ")"
      | otherwise  = "null"
    bodyF =
      if asCsv
        then asCsvF
        else if isSingle
          then asJsonSingleF
          else asJsonF
-- | Build the statement used for mutating requests. The JSON payload is
-- passed as the single statement parameter (see 'encodeUniformObjs') and
-- the result is decoded with 'decodeStandardMay'.
--
-- The shape of the generated SQL depends on the requested representation:
--
-- * 'None' – run the mutation inside the CTE and return a constant row.
-- * 'HeadersOnly' – add @RETURNING@ to the mutation, count the affected
--   rows and, when @isSingle@ is set, build a location fragment from
--   @pKeys@ via 'locationF'.
-- * 'Full' – like 'HeadersOnly', but additionally aggregate the rows
--   produced by @selectQuery@ into the body (CSV, single JSON object or
--   JSON array).
--
-- A 'PayloadParseError' cannot be turned into a statement; that case
-- raises an error with a descriptive message.
createWriteStatement :: QualifiedIdentifier -> SqlQuery -> SqlQuery -> Bool ->
                        PreferRepresentation -> [Text] -> Bool -> Payload ->
                        H.Query UniformObjects (Maybe ResultsWithCount)
createWriteStatement _ _ _ _ _ _ _ (PayloadParseError _) =
  error "createWriteStatement: cannot build a write statement from a PayloadParseError"
createWriteStatement _ _ mutateQuery _ None
                     _ _ (PayloadJSON (UniformObjects _)) =
  H.statement sql encodeUniformObjs decodeStandardMay True
  where
    sql = [qc|
WITH {sourceCTEName} AS ({mutateQuery})
SELECT '', 0, '', '' |]
createWriteStatement qi _ mutateQuery isSingle HeadersOnly
                     pKeys _ (PayloadJSON (UniformObjects _)) =
  H.statement sql encodeUniformObjs decodeStandardMay True
  where
    sql = [qc|
WITH {sourceCTEName} AS ({mutateQuery} RETURNING {fromQi qi}.*)
SELECT {cols}
FROM (SELECT 1 FROM {sourceCTEName}) t |]
    cols = intercalate ", " [
        "'' AS total_result_set",
        "pg_catalog.count(t) AS page_total",
        if isSingle then locationF pKeys else "''",
        "''"
      ]
createWriteStatement qi selectQuery mutateQuery isSingle Full
                     pKeys asCsv (PayloadJSON (UniformObjects _)) =
  H.statement sql encodeUniformObjs decodeStandardMay True
  where
    sql = [qc|
WITH {sourceCTEName} AS ({mutateQuery} RETURNING {fromQi qi}.*)
SELECT {cols}
FROM ({selectQuery}) t |]
    cols = intercalate ", " [
        "'' AS total_result_set",
        "pg_catalog.count(t) AS page_total",
        -- Parenthesised so the alias is attached in both branches; without
        -- the parentheses `<>` binds inside the `else` branch only.
        (if isSingle then locationF pKeys else "''") <> " AS header",
        bodyF <> " AS body"
      ]
    bodyF
      | asCsv = asCsvF
      | isSingle = asJsonSingleF
      | otherwise = asJsonF
-- | Walk a read-request tree and attach to every non-root node the
-- 'Relation' that links it to its parent.
--
-- For a node whose parent selects from exactly one table, the relation is
-- looked up first by table name and then by a foreign-key column whose
-- name matches the node name (optionally followed by @_@ and @id@ / @fk@,
-- case-insensitively). When a relation is found the node name in the
-- query's @from@ list is replaced by the relation's table name. When none
-- is found the result is @Left "no relation between ..."@.
--
-- Any other node (in particular the root, where @parentNode@ is
-- 'Nothing') keeps its query and gets no relation.
addRelations :: Schema -> [Relation] -> Maybe ReadRequest -> ReadRequest -> Either Text ReadRequest
addRelations schema allRelations parentNode node@(Node readNode@(query, (name, _)) forest) =
  case parentNode of
    (Just (Node (Select{from=[parentTable]}, (_, _)) _)) ->
      Node <$> (attach readNode <$> relation) <*> children
      where
        relation =
          note ("no relation between " <> parentTable <> " and " <> name)
               (byTable schema name parentTable <|> byColumn schema parentTable name)
        attach :: (ReadQuery, (NodeName, Maybe Relation)) -> Relation -> (ReadQuery, (NodeName, Maybe Relation))
        attach (q, (n, _)) r = (q {from = renamed}, (n, Just r))
          where
            renamed = [ if t == n then tableName (relTable r) else t | t <- from q ]
    _ -> Node (query, (name, Nothing)) <$> children
  where
    -- Children are processed with the current node as their parent.
    children = mapM (addRelations schema allRelations (Just node)) forest
    -- Both ends of the relation must live in schema @s@.
    inSchema s r = s == tableSchema (relTable r) && s == tableSchema (relFTable r)
    byTable s t1 t2 = find matches allRelations
      where
        matches r = inSchema s r && t1 == tableName (relTable r) && t2 == tableName (relFTable r)
    byColumn s t c = find matches allRelations
      where
        -- `head` is safe here: it is only reached after the length check.
        matches r = inSchema s r && t == tableName (relFTable r) && length (relFColumns r) == 1 && c `colMatches` (colName . head . relFColumns) r
        n `colMatches` rc = (cs ("^" <> rc <> "_?(?:|[iI][dD]|[fF][kK])$") :: BS.ByteString) =~ (cs n :: BS.ByteString)
-- | Walk a read-request tree and prepend join filters to each node's
-- @flt_@ list, based on the relation attached to the node.
--
-- Every node first receives the join conditions of its direct children
-- that are 'Parent' relations. On top of that:
--
-- * no relation, or a 'Parent' relation – nothing else is added;
-- * a 'Child' relation – the node's own join conditions are added;
-- * a 'Many' relation with a link table – the node's own join conditions
--   are added and the link table is prepended to @from@;
-- * anything else – @Left "unknown relation"@.
addJoinConditions :: Schema -> ReadRequest -> Either Text ReadRequest
addJoinConditions schema (Node (query, (n, r)) forest) =
  case r of
    Nothing -> rebuild withParents
    Just rel@Relation{relType=Child} -> rebuild (withOwn rel)
    Just Relation{relType=Parent} -> rebuild withParents
    Just rel@Relation{relType=Many, relLTable=(Just linkTable)} ->
      rebuild linked
      where
        joined = withOwn rel
        linked = joined{from = tableName linkTable : from joined}
    _ -> Left "unknown relation"
  where
    rebuild q = Node (q, (n, r)) <$> children
    children = mapM (addJoinConditions schema) forest
    withOwn rel = addCond withParents (getJoinConditions rel)
    withParents = foldr (flip addCond) query parentJoinConditions
      where
        parentJoinConditions = map (getJoinConditions . snd) parents
        parents = mapMaybe (getParents . rootLabel) forest
        getParents (_, (tbl, Just rel@Relation{relType=Parent})) = Just (tbl, rel)
        getParents _ = Nothing
    addCond q con = q{flt_ = con ++ flt_ q}
-- | Build a statement that calls the stored procedure @qi@ with named
-- arguments taken from the JSON object @params@ and returns the rows it
-- produces aggregated into a JSON array.
--
-- Each argument is rendered as @"name":=value@, with the name quoted by
-- 'pgFmtIdent' and the value rendered by 'insertableValue'.
callProc :: QualifiedIdentifier -> JSON.Object -> H.Query () (Maybe JSON.Value)
callProc qi params =
  H.statement sql HE.unit decodeObj True
  where
    sql = [qc| SELECT array_to_json(
coalesce(array_agg(row_to_json(t)), '\{}')
)::character varying
from ({_callSql}) t |]
    _callSql = [qc| select * from {fromQi qi}({_args}) |] :: BS.ByteString
    _args = intercalate ","
      [ pgFmtIdent argName <> ":=" <> insertableValue argValue
      | (argName, argValue) <- HM.toList params
      ]
    decodeObj = HD.maybeRow (HD.value HD.json)
-- | Table mapping the operator codes accepted in filters to the SQL
-- operator text they are rendered as.
operators :: [(Text, SqlFragment)]
operators =
  [ ("eq", "=")
  , ("gte", ">=")
  , ("gt", ">")
  , ("lte", "<=")
  , ("lt", "<")
  , ("neq", "<>")
  , ("like", "like")
  , ("ilike", "ilike")
  , ("in", "in")
  , ("notin", "not in")
  , ("isnot", "is not")
  , ("is", "is")
  , ("@@", "@@")
  , ("@>", "@>")
  , ("<@", "<@")
  ]
-- | Render a value as a double-quoted SQL identifier: the input is cut at
-- the first NUL character, embedded double quotes are doubled, and the
-- result is wrapped in double quotes.
pgFmtIdent :: SqlFragment -> SqlFragment
pgFmtIdent x = "\"" <> doubled <> "\""
  where
    doubled = replace "\"" "\"\"" (trimNullChars $ cs x)
-- | Render a value as a single-quoted SQL string literal.
--
-- The input is cut at the first NUL character and embedded single quotes
-- are doubled. Backslashes are doubled as well, and when at least one
-- backslash is present the literal is prefixed with @E@.
pgFmtLit :: SqlFragment -> SqlFragment
pgFmtLit x
  | "\\" `isInfixOf` quoted = "E" <> slashed
  | otherwise               = slashed
  where
    quoted  = "'" <> replace "'" "''" (trimNullChars x) <> "'"
    slashed = replace "\\" "\\\\" quoted
-- | Build the @SELECT pg_catalog.count(1)@ query that counts the rows of
-- the main table of a read request.
--
-- Only the request's text filters ('VText') are applied; foreign-key
-- filters are left out. The WHERE clause is omitted entirely when no text
-- filter is present.
--
-- Calling this with a 'DbMutate' request is a programming error and raises
-- an exception with a descriptive message.
requestToCountQuery :: Schema -> DbRequest -> SqlQuery
requestToCountQuery _ (DbMutate _) =
  error "requestToCountQuery: count queries are only defined for read requests"
requestToCountQuery schema (DbRead (Node (Select _ _ conditions _, (mainTbl, _)) _)) =
  unwords [
    "SELECT pg_catalog.count(1)",
    "FROM ", fromQi qi,
    ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi) localConditions )) `emptyOnNull` localConditions
    ]
  where
    qi = QualifiedIdentifier schema mainTbl
    -- Keep filters on literal values, drop join (foreign-key) filters.
    fn Filter{value=VText _} = True
    fn Filter{value=VForeignKey _ _} = False
    localConditions = filter fn conditions
-- | Render a 'DbRequest' as SQL text.
--
-- * 'DbRead' – a SELECT over the node's tables. Child nodes are embedded
--   according to their relation: 'Child' and 'Many' relations become
--   correlated sub-selects aggregated into a JSON array, 'Parent'
--   relations become @LEFT OUTER JOIN@s whose row is emitted with
--   @row_to_json@. Filters that reference a joined parent table are used
--   as the join's @ON@ condition; the remaining filters form the WHERE
--   clause. Tables named 'sourceCTEName' are not schema-qualified.
-- * 'Insert' – @INSERT ... SELECT ... FROM json_populate_recordset@, with
--   the column list taken from the keys of the first payload object and
--   the payload itself expected as parameter @$1@.
-- * 'Update' – @UPDATE ... SET@ with assignments taken from the first
--   payload object, rendered inline via 'insertableValue'.
-- * 'Delete' – @DELETE FROM@ with the request's filters.
--
-- Inputs that cannot be rendered (a payload that failed to parse, an
-- UPDATE without any payload row, a child node without a relation) raise
-- an exception with a descriptive message.
requestToQuery :: Schema -> DbRequest -> SqlQuery
requestToQuery _ (DbMutate (Insert _ (PayloadParseError _))) =
  error "requestToQuery: cannot build an INSERT from a PayloadParseError"
requestToQuery _ (DbMutate (Update _ (PayloadParseError _) _)) =
  error "requestToQuery: cannot build an UPDATE from a PayloadParseError"
requestToQuery schema (DbRead (Node (Select colSelects tbls conditions ord, (nodeName, maybeRelation)) forest)) =
  query
  where
    -- The node may be named differently from the table it reads from.
    mainTbl = fromMaybe nodeName (tableName . relTable <$> maybeRelation)
    tblSchema tbl = if tbl == sourceCTEName then "" else schema
    qi = QualifiedIdentifier (tblSchema mainTbl) mainTbl
    toQi t = QualifiedIdentifier (tblSchema t) t
    query = unwords [
      "SELECT ", intercalate ", " (map (pgFmtSelectItem qi) colSelects ++ selects),
      "FROM ", intercalate ", " (map (fromQi . toQi) tbls),
      unwords (map joinStr joins),
      ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) localConditions )) `emptyOnNull` localConditions,
      orderF (fromMaybe [] ord)
      ]
    orderF ts =
      if null ts
        then ""
        else "ORDER BY " <> clause
      where
        clause = intercalate "," (map queryTerm ts)
        queryTerm :: OrderTerm -> Text
        queryTerm t = " "
          <> cs (pgFmtColumn qi $ otTerm t) <> " "
          <> (cs . show) (otDirection t) <> " "
          <> maybe "" (cs . show) (otNullOrder t) <> " "
    (joins, selects) = foldr getQueryParts ([],[]) forest
    parentTables = map snd joins
    -- Filters consumed by a join's ON clause must not be repeated in WHERE.
    parentConditions = join $ map (( `filter` conditions ) . filterParentConditions) parentTables
    localConditions = conditions \\ parentConditions
    joinStr :: (SqlFragment, TableName) -> SqlFragment
    joinStr (sql, t) = "LEFT OUTER JOIN " <> sql <> " ON " <>
      intercalate " AND " ( map (pgFmtCondition qi ) joinConditions )
      where
        joinConditions = filter (filterParentConditions t) conditions
    filterParentConditions parentTable (Filter _ _ (VForeignKey (QualifiedIdentifier "" t) _)) = parentTable == t
    filterParentConditions _ _ = False
    getQueryParts :: Tree ReadNode -> ([(SqlFragment, TableName)], [SqlFragment]) -> ([(SqlFragment,TableName)], [SqlFragment])
    getQueryParts (Node n@(_, (name, Just Relation{relType=Child,relTable=Table{tableName=table}})) forst) (j,s) = (j,sel:s)
      where
        sel = "COALESCE(("
           <> "SELECT array_to_json(array_agg(row_to_json("<>pgFmtIdent table<>"))) "
           <> "FROM (" <> subquery <> ") " <> pgFmtIdent table
           <> "), '[]') AS " <> pgFmtIdent name
           where subquery = requestToQuery schema (DbRead (Node n forst))
    getQueryParts (Node n@(_, (name, Just Relation{relType=Parent,relTable=Table{tableName=table}})) forst) (j,s) = (joi:j,sel:s)
      where
        sel = "row_to_json(" <> pgFmtIdent table <> ".*) AS "<>pgFmtIdent name
        joi = ("( " <> subquery <> " ) AS " <> pgFmtIdent table, table)
          where subquery = requestToQuery schema (DbRead (Node n forst))
    getQueryParts (Node n@(_, (name, Just Relation{relType=Many,relTable=Table{tableName=table}})) forst) (j,s) = (j,sel:s)
      where
        sel = "COALESCE (("
           <> "SELECT array_to_json(array_agg(row_to_json("<>pgFmtIdent table<>"))) "
           <> "FROM (" <> subquery <> ") " <> pgFmtIdent table
           <> "), '[]') AS " <> pgFmtIdent name
           where subquery = requestToQuery schema (DbRead (Node n forst))
    getQueryParts (Node (_,(_,Nothing)) _) _ =
      error "requestToQuery: embedded node has no relation to its parent"
requestToQuery schema (DbMutate (Insert mainTbl (PayloadJSON (UniformObjects rows)))) =
  let qi = QualifiedIdentifier schema mainTbl
      -- Column list comes from the first object only.
      cols = map pgFmtIdent $ fromMaybe [] (HM.keys <$> (rows V.!? 0))
      colsString = intercalate ", " cols in
  unwords [
    "INSERT INTO ", fromQi qi,
    " (" <> colsString <> ")" <>
    " SELECT " <> colsString <>
    " FROM json_populate_recordset(null::" , fromQi qi, ", $1)"
    ]
requestToQuery schema (DbMutate (Update mainTbl (PayloadJSON (UniformObjects rows)) conditions)) =
  case rows V.!? 0 of
    Just obj ->
      let assignments = map
            (\(k,v) -> pgFmtIdent k <> "=" <> insertableValue v) $ HM.toList obj in
      unwords [
        "UPDATE ", fromQi qi,
        " SET " <> intercalate "," assignments <> " ",
        ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) conditions )) `emptyOnNull` conditions
        ]
    Nothing -> error "requestToQuery: UPDATE payload contains no rows"
  where
    qi = QualifiedIdentifier schema mainTbl
requestToQuery schema (DbMutate (Delete mainTbl conditions)) =
  query
  where
    qi = QualifiedIdentifier schema mainTbl
    query = unwords [
      "DELETE FROM ", fromQi qi,
      ("WHERE " <> intercalate " AND " ( map (pgFmtCondition qi ) conditions )) `emptyOnNull` conditions
      ]
-- | Name of the common table expression that the statements in this module
-- wrap the main select/mutate query in.
sourceCTEName :: SqlFragment
sourceCTEName = "pg_source"
-- | Render a JSON value as plain text without surrounding JSON quotes.
--
-- Strings are returned as-is, numbers are printed in fixed notation (with
-- no decimals when integral), booleans use their 'show' form, and every
-- other value is rendered as its encoded JSON.
unquoted :: JSON.Value -> Text
unquoted val = case val of
  JSON.String t -> t
  JSON.Number n
    | isInteger n -> cs $ formatScientific Fixed (Just 0) n
    | otherwise   -> cs $ formatScientific Fixed Nothing n
  JSON.Bool b -> cs (show b)
  other -> cs (JSON.encode other)
-- | SQL expression producing the CSV body: a header line built from the
-- JSON keys of the first row of the 'sourceCTEName' CTE, a newline, then
-- one line per row of @t@ taken from the row's text form with its first
-- and last character removed.
asCsvF :: SqlFragment
asCsvF = mconcat [headerF, " || '\n' || ", rowsF]
  where
    headerF = mconcat
      [ "(SELECT coalesce(string_agg(a.k, ','), '')"
      , " FROM ("
      , " SELECT json_object_keys(r)::TEXT as k"
      , " FROM ( "
      , " SELECT row_to_json(hh) as r from ", sourceCTEName, " as hh limit 1"
      , " ) s"
      , " ) a"
      , ")"
      ]
    rowsF = "coalesce(string_agg(substring(t::text, 2, length(t::text) - 2), '\n'), '')"
-- | SQL expression aggregating all rows of @t@ into a JSON array, falling
-- back to @'[]'@ when the aggregate is null.
asJsonF :: SqlFragment
asJsonF = "coalesce(array_to_json(array_agg(row_to_json(t))), '[]')::character varying"
-- | SQL expression rendering the rows of @t@ as comma-separated JSON
-- objects without enclosing array brackets, falling back to the empty
-- string when the aggregate is null.
asJsonSingleF :: SqlFragment
asJsonSingleF = "coalesce(string_agg(row_to_json(t)::text, ','), '')::character varying "
-- | SQL sub-select that builds a query string of the form
-- @key=eq.value&key=is.null@ from the first row of the 'sourceCTEName'
-- CTE.
--
-- When @pKeys@ is non-empty only those keys are included; otherwise every
-- key of the row is used. Key names are rendered through 'pgFmtLit' so
-- that a name containing a quote or backslash cannot break out of the SQL
-- string literal; ordinary names produce exactly the same SQL as a plain
-- @'name'@ literal.
locationF :: [Text] -> SqlFragment
locationF pKeys =
  "(" <>
  " WITH s AS (SELECT row_to_json(ss) as r from " <> sourceCTEName <> " as ss limit 1)" <>
  " SELECT string_agg(json_data.key || '=' || coalesce( 'eq.' || json_data.value, 'is.null'), '&')" <>
  " FROM s, json_each_text(s.r) AS json_data" <>
  keyFilter <> ")"
  where
    keyFilter
      | null pKeys = ""
      | otherwise  = " WHERE json_data.key IN (" <> intercalate "," (map pgFmtLit pKeys) <> ")"
-- | Render a range as a @LIMIT ... OFFSET ...@ clause; a range without a
-- limit is rendered as @LIMIT ALL@.
limitF :: NonnegRange -> SqlFragment
limitF r = unwords ["LIMIT", rowLimit, "OFFSET", rowOffset]
  where
    rowLimit  = maybe "ALL" (cs . show) (rangeLimit r)
    rowOffset = cs (show (rangeOffset r))
-- | Render a qualified identifier as @"schema"."name"@, or just @"name"@
-- when the schema is empty.
fromQi :: QualifiedIdentifier -> SqlFragment
fromQi t
  | schemaPart == "" = namePart
  | otherwise        = pgFmtIdent schemaPart <> "." <> namePart
  where
    schemaPart = qiSchema t
    namePart   = pgFmtIdent (qiName t)
-- | Turn a relation into the equality filters that join its two sides.
--
-- For 'Child' and 'Parent' relations the relation's columns are paired
-- with its foreign columns. For 'Many' relations both the table's and the
-- foreign table's columns are paired with the link-table columns (a
-- missing link column list contributes no filters).
--
-- The left-hand side of each filter is qualified with the relation
-- table's schema, except for 'Parent' relations where the schema is left
-- empty.
getJoinConditions :: Relation -> [Filter]
getJoinConditions (Relation t cols ft fcs typ lt lc1 lc2) =
  case typ of
    Child  -> direct
    Parent -> direct
    Many   -> zipWith (toFilter tN ltN) cols (fromMaybe [] lc1) ++ zipWith (toFilter ftN ltN) fcs (fromMaybe [] lc2)
  where
    direct = zipWith (toFilter tN ftN) cols fcs
    s | typ == Parent = ""
      | otherwise     = tableSchema t
    tN  = tableName t
    ftN = tableName ft
    ltN = maybe "" tableName lt
    toFilter :: Text -> Text -> Column -> Column -> Filter
    toFilter tb ftb c fc =
      Filter (colName c, Nothing) "=" (VForeignKey (QualifiedIdentifier s tb) (ForeignKey retargeted))
      where
        -- Same foreign column, but attributed to table @ftb@.
        retargeted = fc{colTable=(colTable fc){tableName=ftb}}
-- | Return @val@ unless the list is empty, in which case return the empty
-- text. Used to drop a clause whose operand list is empty.
emptyOnNull :: Text -> [a] -> Text
emptyOnNull val x
  | null x    = ""
  | otherwise = val
-- | Render a JSON value for direct inclusion in SQL: JSON null becomes the
-- keyword @null@; anything else becomes the quoted literal of its
-- 'unquoted' text followed by a @::unknown@ cast.
insertableValue :: JSON.Value -> SqlFragment
insertableValue v = case v of
  JSON.Null -> "null"
  _         -> pgFmtLit (unquoted v) <> "::unknown"
-- | Render the right-hand side of an @is@ / @is not@ comparison. When the
-- value, lower-cased, is one of @null@, @true@ or @false@ that keyword is
-- emitted bare; any other value is emitted as a quoted literal cast to
-- @unknown@.
whiteList :: Text -> SqlFragment
whiteList val = fromMaybe quoted (find (lowered ==) ["null","true","false"])
  where
    lowered = toLower val
    quoted  = cs (pgFmtLit val) <> "::unknown "
-- | Render a column reference qualified by its table. The column @*@ is
-- emitted unquoted; every other name is quoted with 'pgFmtIdent'.
pgFmtColumn :: QualifiedIdentifier -> Text -> SqlFragment
pgFmtColumn table c = fromQi table <> "." <> column
  where
    column = if c == "*" then "*" else pgFmtIdent c
-- | Render a field: the qualified column followed by its JSON path
-- accessors, if any.
pgFmtField :: QualifiedIdentifier -> Field -> SqlFragment
pgFmtField table (c, jp) = column <> path
  where
    column = pgFmtColumn table c
    path   = pgFmtJsonPath jp
-- | Render one item of a SELECT list: the field, optionally wrapped in a
-- @CAST ( ... AS type )@, followed by the alias derived from its JSON
-- path (if it has one).
--
-- NOTE(review): the cast type is spliced into the SQL as-is; presumably it
-- is validated upstream -- confirm.
pgFmtSelectItem :: QualifiedIdentifier -> SelectItem -> SqlFragment
pgFmtSelectItem table (f@(_, jp), mCast) = selected <> pgFmtAsJsonPath jp
  where
    field    = pgFmtField table f
    selected = maybe field (\cast -> "CAST (" <> field <> " AS " <> cast <> " )") mCast
-- | Render a single filter as a SQL boolean expression of the form
-- @[not] column operator value@.
--
-- The operator string is split on @.@. A leading @not@ followed by an
-- operator code negates the condition; otherwise the first segment is the
-- operator code. A bare @not@ with nothing after it is treated as an
-- ordinary (unknown) operator code instead of crashing, so it falls back
-- to whatever 'pgFmtOperator' does for unknown codes.
--
-- * For @is@ / @isnot@ the value goes through 'whiteList'.
-- * For text values the column is qualified with @table@ and may carry a
--   JSON path; the value is rendered by 'pgFmtValue'.
-- * For foreign-key values both sides are column references; the foreign
--   side is left unqualified by schema when its table is 'sourceCTEName'.
pgFmtCondition :: QualifiedIdentifier -> Filter -> SqlFragment
pgFmtCondition table (Filter (col,jp) ops val) =
  notOp <> " " <> sqlCol <> " " <> pgFmtOperator opCode <> " " <>
    if opCode `elem` ["is","isnot"] then whiteList (getInner val) else sqlValue
  where
    -- Total replacement for the former `head rest`, which failed on "not".
    (notOp, opCode) = case split (=='.') ops of
      "not" : op : _ -> ("not", op)
      op : _         -> ("", op)
      []             -> ("", "")
    sqlCol = case val of
      VText _ -> pgFmtColumn table col <> pgFmtJsonPath jp
      VForeignKey qi _ -> pgFmtColumn qi col
    sqlValue = valToStr val
    getInner v = case v of
      VText s -> s
      _ -> ""
    valToStr v = case v of
      VText s -> pgFmtValue opCode s
      VForeignKey (QualifiedIdentifier s _) (ForeignKey Column{colTable=Table{tableName=ft}, colName=fc}) -> pgFmtColumn qi fc
        where qi = QualifiedIdentifier (if ft == sourceCTEName then "" else s) ft
      _ -> ""
-- | Render a filter value according to its operator code:
--
-- * @like@ / @ilike@ – @*@ is translated to @%@;
-- * @in@ / @notin@ – the value is split on commas into a parenthesised
--   list of literals;
-- * @\@\@@ – the value is wrapped in @to_tsquery(...)@;
-- * anything else – a single literal.
--
-- Every literal is quoted with 'pgFmtLit' and cast to @unknown@.
pgFmtValue :: Text -> Text -> SqlFragment
pgFmtValue opCode val
  | opCode `elem` ["like", "ilike"] = unknownLiteral (T.map star val)
  | opCode `elem` ["in", "notin"]   = "(" <> intercalate ", " (map unknownLiteral $ split (==',') val) <> ") "
  | opCode == "@@"                  = "to_tsquery(" <> unknownLiteral val <> ") "
  | otherwise                       = unknownLiteral val
  where
    star c
      | c == '*'  = '%'
      | otherwise = c
    unknownLiteral lit = pgFmtLit lit <> "::unknown "
-- | Translate an operator code into its SQL operator using the
-- 'operators' table; unknown codes fall back to @=@.
pgFmtOperator :: Text -> SqlFragment
pgFmtOperator opCode = M.findWithDefault "=" opCode (M.fromList operators)
-- | Render a JSON path as a chain of accessors: every key but the last is
-- reached with @->@, the last one with @->>@. A missing or empty path
-- renders as the empty string. Keys are quoted with 'pgFmtLit'.
pgFmtJsonPath :: Maybe JsonPath -> SqlFragment
pgFmtJsonPath mPath = case mPath of
  Just [x]    -> "->>" <> pgFmtLit x
  Just (x:xs) -> "->" <> pgFmtLit x <> pgFmtJsonPath (Just xs)
  _           -> ""
-- | Render the @AS alias@ suffix for a selected JSON path, using the last
-- key of the path as the alias. A missing path, or an empty one, yields
-- the empty string (an empty path previously hit the partial 'last').
--
-- NOTE(review): the alias is emitted without identifier quoting, unlike
-- other identifiers in this module, which go through 'pgFmtIdent'. Quoting
-- it would change the case of the resulting column name for mixed-case
-- keys, so it is left as-is here -- confirm whether that is intended.
pgFmtAsJsonPath :: Maybe JsonPath -> SqlFragment
pgFmtAsJsonPath (Just xx@(_:_)) = " AS " <> last xx
pgFmtAsJsonPath _ = ""
-- | Keep only the part of the text that precedes the first NUL character.
trimNullChars :: Text -> Text
trimNullChars txt = T.takeWhile notNul txt
  where
    notNul c = c /= '\NUL'
-- | Transaction isolation levels that 'inTransaction' can request.
data Isolation = ReadCommitted | RepeatableRead | Serializable
-- | Run a session between a @begin ISOLATION LEVEL ...@ and a @commit@,
-- returning the session's result.
--
-- NOTE(review): no @rollback@ is ever issued here. If the wrapped session
-- aborts before the commit is reached, the transaction is presumably left
-- open on the connection -- confirm how callers deal with that.
inTransaction :: Isolation -> H.Session a -> H.Session a
inTransaction lvl f =
  H.sql ("begin " <> isolate <> ";") *> f <* H.sql "commit;"
  where
    isolate = case lvl of
      ReadCommitted  -> "ISOLATION LEVEL READ COMMITTED"
      RepeatableRead -> "ISOLATION LEVEL REPEATABLE READ"
      Serializable   -> "ISOLATION LEVEL SERIALIZABLE"