module PostgREST.QueryBuilder (
callProc
, createReadStatement
, createWriteStatement
, getJoinConditions
, pgFmtIdent
, pgFmtLit
, requestToQuery
, requestToCountQuery
, sourceCTEName
, unquoted
, ResultsWithCount
, pgFmtEnvVar
) where
import qualified Hasql.Query as H
import qualified Hasql.Encoders as HE
import qualified Hasql.Decoders as HD
import qualified Data.Aeson as JSON
import PostgREST.RangeQuery (NonnegRange, rangeLimit, rangeOffset, allRange)
import Data.Functor.Contravariant (contramap)
import qualified Data.HashMap.Strict as HM
import Data.Maybe
import Data.Text (intercalate, unwords, replace, isInfixOf, toLower)
import qualified Data.Text as T (map, takeWhile, null)
import qualified Data.Text.Encoding as T
import Data.Tree (Tree(..))
import qualified Data.Vector as V
import PostgREST.Types
import Text.InterpolatedString.Perl6 (qc)
import qualified Data.ByteString.Char8 as BS
import Data.Scientific ( FPFormat (..)
, formatScientific
, isInteger
)
import Protolude hiding (from, intercalate, ord, cast)
import PostgREST.ApiRequest (PreferRepresentation (..))
-- | Columns decoded from a statement result, in positional order: the
-- optional total result-set count, the number of rows in the returned page,
-- the header entries, and the response body.
type ResultsWithCount = (Maybe Int64, Int64, [BS.ByteString], BS.ByteString)
-- | Row decoder for 'ResultsWithCount'.
--
-- Columns are read positionally: a nullable @int8@, a non-null @int8@, a
-- one-dimensional array of @bytea@ values and a single @bytea@ value.
standardRow :: HD.Row ResultsWithCount
standardRow =
  (,,,)
    <$> totalCount
    <*> pageCount
    <*> headerEntries
    <*> body
  where
    totalCount = HD.nullableValue HD.int8
    pageCount = HD.value HD.int8
    headerEntries = HD.value (HD.array (HD.arrayDimension replicateM (HD.arrayValue HD.bytea)))
    body = HD.value HD.bytea
-- | SQL expression for an empty @text[]@ array; selected as the header column
-- when no location entries are requested.
noLocationF :: Text
noLocationF = "array[]::text[]"
-- | Result decoder that expects exactly one row shaped like 'standardRow'.
decodeStandard :: HD.Result ResultsWithCount
decodeStandard = HD.singleRow standardRow
-- | Result decoder for zero or one row shaped like 'standardRow'.
decodeStandardMay :: HD.Result (Maybe ResultsWithCount)
decodeStandardMay = HD.maybeRow standardRow
-- | Parameter encoder that sends a 'PayloadJSON' as a single @json@ value:
-- every object of the payload is wrapped in 'JSON.Object' and the whole
-- vector is sent as one 'JSON.Array'.
encodeUniformObjs :: HE.Params PayloadJSON
encodeUniformObjs = contramap toJsonArray (HE.value HE.json)
  where
    toJsonArray payload = JSON.Array (V.map JSON.Object (unPayloadJSON payload))
-- | Build the statement that runs a read request.
--
-- The select query is placed in the 'sourceCTEName' CTE and the outer query
-- produces the four columns expected by 'decodeStandard'.
--
-- Arguments, in order:
--
-- * @selectQuery@  – SQL of the read request, used as the CTE body.
-- * @countQuery@   – SQL used for the total count when @countTotal@ is set.
-- * @isSingle@     – render the body with 'asJsonSingleF'.
-- * @countTotal@   – when 'False' the total column is the literal @null@.
-- * @asCsv@        – render the body with 'asCsvF' (takes precedence over
--                    @isSingle@ and @binaryField@).
-- * @binaryField@  – when present (and neither flag above is set) the body is
--                    the aggregation of that single field, see 'asBinaryF'.
--
-- The trailing 'False' is forwarded unchanged to 'H.statement'
-- (NOTE(review): presumably the "preparable" flag — confirm against Hasql).
createReadStatement :: SqlQuery -> SqlQuery -> Bool -> Bool -> Bool -> Maybe FieldName ->
                       H.Query () ResultsWithCount
createReadStatement selectQuery countQuery isSingle countTotal asCsv binaryField =
  unicodeStatement sql HE.unit decodeStandard False
  where
    sql = [qc|
WITH {sourceCTEName} AS ({selectQuery}) SELECT {cols}
FROM ( SELECT * FROM {sourceCTEName}) _postgrest_t |]
    countResultF = if countTotal then "(" <> countQuery <> ")" else "null"
    cols = intercalate ", "
      [ countResultF <> " AS total_result_set"
      , "pg_catalog.count(_postgrest_t) AS page_total"
      , noLocationF <> " AS header"
      , bodyF <> " AS body"
      ]
    -- A pattern guard replaces the former isJust/fromJust pair, so the field
    -- name is only used where it is known to exist.
    bodyF
      | asCsv = asCsvF
      | isSingle = asJsonSingleF
      | Just field <- binaryField = asBinaryF field
      | otherwise = asJsonF
-- | Build the statement that runs a mutation and shapes its result.
--
-- The mutation is placed in the 'sourceCTEName' CTE; what is selected from it
-- depends on the requested representation:
--
-- * 'None'        – a fixed row of empty/zero placeholders.
-- * 'HeadersOnly' – counts and headers computed over the mutated rows.
-- * 'Full'        – counts, headers and body computed over @selectQuery@.
--
-- @wantHdrs@ switches the header column between 'locationF' (restricted to
-- @pKeys@) and 'noLocationF'. @asCsv@ takes precedence over @wantSingle@ when
-- choosing the body format. The payload is sent through 'encodeUniformObjs'.
createWriteStatement :: SqlQuery -> SqlQuery -> Bool -> Bool -> Bool ->
                        PreferRepresentation -> [Text] ->
                        H.Query PayloadJSON (Maybe ResultsWithCount)
createWriteStatement selectQuery mutateQuery wantSingle wantHdrs asCsv rep pKeys =
  unicodeStatement sql encodeUniformObjs decodeStandardMay True
  where
    sql = case rep of
      None -> [qc|
WITH {sourceCTEName} AS ({mutateQuery})
SELECT '', 0, {noLocationF}, '' |]
      HeadersOnly -> [qc|
WITH {sourceCTEName} AS ({mutateQuery})
SELECT {cols}
FROM (SELECT 1 FROM {sourceCTEName}) _postgrest_t |]
      Full -> [qc|
WITH {sourceCTEName} AS ({mutateQuery})
SELECT {cols}
FROM ({selectQuery}) _postgrest_t |]

    cols = intercalate ", " [totalCol, pageCol, headerCol, bodyCol]

    totalCol = "'' AS total_result_set"
    pageCol = "pg_catalog.count(_postgrest_t) AS page_total"

    -- Only the 'noLocationF' variant carries the " AS header" alias; the
    -- decoder reads columns by position so the alias is cosmetic.
    headerCol
      | wantHdrs = locationF pKeys
      | otherwise = noLocationF <> " AS header"

    bodyCol = case rep of
      Full -> bodyF <> " AS body"
      _ -> "''"

    bodyF
      | asCsv = asCsvF
      | wantSingle = asJsonSingleF
      | otherwise = asJsonF
-- | Columns decoded from a procedure call, in positional order: the optional
-- total result-set count, the number of rows in the page, and the body.
type ProcResults = (Maybe Int64, Int64, ByteString)
-- | Build the statement that invokes a stored procedure.
--
-- The call itself (@select * from proc(args)@) becomes the 'sourceCTEName'
-- CTE and @selectQuery@ is evaluated on top of it.
--
-- Arguments, in order:
--
-- * @qi@           – qualified name of the procedure.
-- * @params@       – call arguments. With @paramsAsJson@ they are passed as a
--                    single @json@ literal, otherwise as named arguments
--                    (@"name":=value@).
-- * @selectQuery@  – query whose rows form the response.
-- * @countQuery@   – used for the total when @countTotal@ is set, otherwise
--                    the total column is @null::bigint@.
-- * the range argument is ignored.
-- * @isSingle@ / @asCsv@ – choose the body format; @isSingle@ wins here.
--
-- When the result has at most one row, the body is taken from the column
-- named after the procedure in that row, falling back to the regular body
-- format when that value is null.
callProc :: QualifiedIdentifier -> JSON.Object -> SqlQuery -> SqlQuery -> NonnegRange ->
            Bool -> Bool -> Bool -> Bool -> H.Query () (Maybe ProcResults)
callProc qi params selectQuery countQuery _ countTotal isSingle paramsAsJson asCsv =
  unicodeStatement sql HE.unit decodeProc True
  where
    sql = [qc|
WITH {sourceCTEName} AS ({_callSql})
SELECT
{countResultF} AS total_result_set,
pg_catalog.count(_postgrest_t) AS page_total,
case
when pg_catalog.count(*) > 1 then
{bodyF}
else
coalesce(((array_agg(row_to_json(_postgrest_t)))[1]->{_procName})::character varying, {bodyF})
end as body
FROM ({selectQuery}) _postgrest_t;
|]
    countResultF = if countTotal then "(" <> countQuery <> ")" else "null::bigint" :: Text
    _args
      | paramsAsJson = insertableValueWithType "json" $ JSON.Object params
      | otherwise = intercalate "," $ map _assignment (HM.toList params)
    _procName = pgFmtLit $ qiName qi
    _assignment (n, v) = pgFmtIdent n <> ":=" <> insertableValue v
    _callSql = [qc|select * from {fromQi qi}({_args}) |] :: Text
    decodeProc = HD.maybeRow procRow
    procRow = (,,) <$> HD.nullableValue HD.int8
                   <*> HD.value HD.int8
                   <*> HD.value HD.bytea
    bodyF
      | isSingle = asJsonSingleF
      | asCsv = asCsvF
      | otherwise = asJsonF
-- | Quote a name as an SQL identifier: the text is cut at the first NUL
-- character, embedded double quotes are doubled and the result is wrapped in
-- double quotes.
pgFmtIdent :: SqlFragment -> SqlFragment
pgFmtIdent x = "\"" <> escaped <> "\""
  where
    escaped = replace "\"" "\"\"" (trimNullChars (toS x))
-- | Quote a value as an SQL string literal.
--
-- The text is cut at the first NUL character, single quotes are doubled and
-- the result is wrapped in single quotes. If the value contains a backslash,
-- every backslash is doubled and the literal is prefixed with @E@.
pgFmtLit :: SqlFragment -> SqlFragment
pgFmtLit x
  | "\\" `isInfixOf` quoted = "E" <> slashed
  | otherwise = slashed
  where
    quoted = "'" <> replace "'" "''" (trimNullChars x) <> "'"
    -- Identical to 'quoted' when there is no backslash to double.
    slashed = replace "\\" "\\\\" quoted
-- | Build the @SELECT pg_catalog.count(*)@ query matching a read request.
--
-- Only the root node of the request is considered. Filters that compare
-- against a foreign key ('VForeignKey') are dropped; the remaining filters
-- and the logic trees form the WHERE clause.
--
-- Not defined for mutations (evaluates to 'undefined').
requestToCountQuery :: Schema -> DbRequest -> SqlQuery
requestToCountQuery _ (DbMutate _) = undefined
requestToCountQuery schema (DbRead (Node (Select _ _ conditions logic_ _ _, (mainTbl, _, _)) _)) =
  unwords
    [ "SELECT pg_catalog.count(*)"
    , "FROM ", fromQi qi
      -- The guard looks at the terms that are actually rendered. Testing the
      -- unfiltered @conditions@ would emit a dangling "WHERE " whenever every
      -- condition is a foreign-key filter and there are no logic trees.
    , ("WHERE " <> intercalate " AND " whereTerms) `emptyOnFalse` null whereTerms
    ]
  where
    qi = removeSourceCTESchema schema mainTbl
    localFilter :: Filter -> Bool
    localFilter Filter{operation=Operation{expr=(_, val)}} = case val of
      VText _ -> True
      VTextL _ -> True
      VForeignKey _ _ -> False
    localConditions = filter localFilter conditions
    whereTerms = map (pgFmtFilter qi) localConditions ++ map (pgFmtLogicTree qi) logic_
-- | Render a 'DbRequest' as SQL text.
--
-- * 'DbRead': a SELECT over the node's tables. Child and many-to-many
--   sub-nodes are embedded as JSON arrays through correlated sub-selects,
--   parent sub-nodes through a LEFT OUTER JOIN plus a @row_to_json@ column.
--   The boolean argument is 'True' when the query is rendered as the parent
--   side of a join, in which case no LIMIT/OFFSET is emitted.
-- * 'Insert': INSERT ... SELECT from @json_populate_recordset@ fed by
--   parameter @$1@; the column list comes from the keys of the first payload
--   object.
-- * 'Update': UPDATE with literal assignments taken from the first payload
--   object; 'undefined' when the payload has no objects.
-- * 'Delete': DELETE with the given conditions.
requestToQuery :: Schema -> Bool -> DbRequest -> SqlQuery
requestToQuery schema isParent (DbRead (Node (Select colSelects tbls conditions logic_ ord range, (nodeName, maybeRelation, _)) forest)) =
  unwords
    [ "SELECT ", intercalate ", " (map (pgFmtSelectItem qi) colSelects ++ selects)
    , "FROM ", intercalate ", " (map (fromQi . toQi) tbls)
    , unwords joins
    , ("WHERE " <> intercalate " AND " whereTerms) `emptyOnFalse` (null conditions && null logic_)
    , orderF (fromMaybe [] ord)
    , if isParent then "" else limitF range
    ]
  where
    toQi = removeSourceCTESchema schema
    -- A node reached through a relation is addressed by the relation's table.
    mainTbl = maybe nodeName (tableName . relTable) maybeRelation
    qi = toQi mainTbl
    whereTerms = map (pgFmtFilter qi) conditions ++ map (pgFmtLogicTree qi) logic_

    orderF :: [OrderTerm] -> SqlFragment
    orderF [] = ""
    orderF terms = "ORDER BY " <> intercalate "," (map queryTerm terms)

    queryTerm :: OrderTerm -> Text
    queryTerm term =
      mconcat
        [ " "
        , toS (pgFmtField qi (otTerm term)), " "
        , maybe "" show (otDirection term), " "
        , maybe "" show (otNullOrder term), " "
        ]

    -- Extra JOIN clauses and SELECT items contributed by the sub-nodes.
    (joins, selects) = foldr getQueryParts ([], []) forest

    -- Correlated sub-select that aggregates a sub-node's rows into a JSON
    -- array; @opener@ is the literal text that starts the COALESCE call.
    embedAsArray :: SqlFragment -> Text -> Text -> Tree ReadNode -> SqlFragment
    embedAsArray opener table label node =
      mconcat
        [ opener
        , "SELECT array_to_json(array_agg(row_to_json(", pgFmtIdent table, "))) "
        , "FROM (", requestToQuery schema False (DbRead node), ") ", pgFmtIdent table
        , "), '[]') AS ", pgFmtIdent label
        ]

    getQueryParts :: Tree ReadNode -> ([SqlFragment], [SqlFragment]) -> ([SqlFragment], [SqlFragment])
    getQueryParts node@(Node (_, (name, Just Relation{relType=Child, relTable=Table{tableName=table}}, alias)) _) (js, ss) =
      (js, embedAsArray "COALESCE((" table (fromMaybe name alias) node : ss)
    getQueryParts node@(Node (_, (name, Just rel@Relation{relType=Parent, relTable=Table{tableName=table}}, alias)) _) (js, ss) =
      (joinF : js, selectF : ss)
      where
        nodeLabel = fromMaybe name alias
        localTable = table <> "_" <> nodeLabel
        subquery = requestToQuery schema True (DbRead node)
        selectF = "row_to_json(" <> pgFmtIdent localTable <> ".*) AS " <> pgFmtIdent nodeLabel
        joinF =
          " LEFT OUTER JOIN ( " <> subquery <> " ) AS " <> pgFmtIdent localTable
            <> " ON " <> intercalate " AND " (map (pgFmtFilter qi . retarget) (getJoinConditions rel))
        -- Schema-less foreign-key filters are pointed at the join alias.
        retarget (Filter a (Operation b (c, VForeignKey (QualifiedIdentifier "" _) d))) =
          Filter a (Operation b (c, VForeignKey (QualifiedIdentifier "" localTable) d))
        retarget other = other
    getQueryParts node@(Node (_, (name, Just Relation{relType=Many, relTable=Table{tableName=table}}, alias)) _) (js, ss) =
      (js, embedAsArray "COALESCE ((" table (fromMaybe name alias) node : ss)
    getQueryParts _ _ = undefined
requestToQuery schema _ (DbMutate (Insert mainTbl (PayloadJSON rows) returnings)) =
  insInto <> vals <> ret
  where
    qi = QualifiedIdentifier schema mainTbl
    target = fromQi qi
    -- Column list taken from the keys of the first payload object, if any.
    colsString = intercalate ", " (map pgFmtIdent (maybe [] HM.keys (rows V.!? 0)))
    noColumns = T.null colsString
    insInto =
      unwords
        [ "INSERT INTO"
        , target
        , if noColumns then "" else "(" <> colsString <> ")"
        ]
    vals
      | noColumns && V.null rows = "SELECT null WHERE false"
      | noColumns = "DEFAULT VALUES"
      | otherwise = unwords ["SELECT", colsString, "FROM json_populate_recordset(null::", target, ", $1)"]
    ret
      | null returnings = ""
      | otherwise = unwords [" RETURNING ", intercalate ", " (map (pgFmtColumn qi) returnings)]
requestToQuery schema _ (DbMutate (Update mainTbl (PayloadJSON rows) conditions logic_ returnings)) =
  case rows V.!? 0 of
    Nothing -> undefined
    Just obj ->
      unwords
        [ "UPDATE ", fromQi qi
        , " SET " <> intercalate "," (assignmentsOf obj) <> " "
        , whereF
        , returningF
        ]
  where
    qi = QualifiedIdentifier schema mainTbl
    assignmentsOf obj = [pgFmtIdent k <> "=" <> insertableValue v | (k, v) <- HM.toList obj]
    whereTerms = map (pgFmtFilter qi) conditions ++ map (pgFmtLogicTree qi) logic_
    whereF = ("WHERE " <> intercalate " AND " whereTerms) `emptyOnFalse` (null conditions && null logic_)
    returningF = ("RETURNING " <> intercalate ", " (map (pgFmtColumn qi) returnings)) `emptyOnFalse` null returnings
requestToQuery schema _ (DbMutate (Delete mainTbl conditions logic_ returnings)) =
  unwords
    [ "DELETE FROM ", fromQi qi
    , whereF
    , returningF
    ]
  where
    qi = QualifiedIdentifier schema mainTbl
    whereTerms = map (pgFmtFilter qi) conditions ++ map (pgFmtLogicTree qi) logic_
    whereF = ("WHERE " <> intercalate " AND " whereTerms) `emptyOnFalse` (null conditions && null logic_)
    returningF = ("RETURNING " <> intercalate ", " (map (pgFmtColumn qi) returnings)) `emptyOnFalse` null returnings
-- | Name of the common table expression that wraps the main query in the
-- generated statements.
sourceCTEName :: SqlFragment
sourceCTEName = "pg_source"
-- | Qualify a table name with the schema, except for the 'sourceCTEName'
-- pseudo-table, which gets an empty schema.
removeSourceCTESchema :: Schema -> TableName -> QualifiedIdentifier
removeSourceCTESchema schema tbl
  | tbl == sourceCTEName = QualifiedIdentifier "" tbl
  | otherwise = QualifiedIdentifier schema tbl
-- | Plain-text rendering of a JSON value, without JSON string quoting.
--
-- Strings are returned as they are, numbers in fixed notation (with no
-- decimals when integral), booleans through 'show', and every other value as
-- its JSON encoding.
unquoted :: JSON.Value -> Text
unquoted value = case value of
  JSON.String t -> t
  JSON.Number n -> toS (formatScientific Fixed decimals n)
    where
      decimals = if isInteger n then Just 0 else Nothing
  JSON.Bool b -> show b
  other -> toS (JSON.encode other)
-- | SQL expression producing a CSV document: a header line followed by a
-- newline and one line per row.
--
-- The header is the comma-joined list of JSON keys of the first row of
-- 'sourceCTEName'; each body line is the row's text form with its first and
-- last character removed.
asCsvF :: SqlFragment
asCsvF = mconcat [headerF, " || '\n' || ", bodyF]
  where
    headerF =
      mconcat
        [ "(SELECT coalesce(string_agg(a.k, ','), '')"
        , " FROM ("
        , " SELECT json_object_keys(r)::TEXT as k"
        , " FROM ( "
        , " SELECT row_to_json(hh) as r from ", sourceCTEName, " as hh limit 1"
        , " ) s"
        , " ) a"
        , ")"
        ]
    bodyF = "coalesce(string_agg(substring(_postgrest_t::text, 2, length(_postgrest_t::text) - 2), '\n'), '')"
-- | SQL expression aggregating all rows into a JSON array, or @[]@ when the
-- aggregate is null.
asJsonF :: SqlFragment
asJsonF = "coalesce(array_to_json(array_agg(row_to_json(_postgrest_t))), '[]')::character varying"
-- | SQL expression joining each row's JSON text with commas, without
-- surrounding brackets; the empty string when the aggregate is null.
asJsonSingleF :: SqlFragment
asJsonSingleF = "coalesce(string_agg(row_to_json(_postgrest_t)::text, ','), '')::character varying "
-- | SQL expression concatenating the given field of every row with no
-- separator; the empty string when the aggregate is null.
asBinaryF :: FieldName -> SqlFragment
asBinaryF fieldName =
  mconcat
    [ "coalesce(string_agg(_postgrest_t."
    , pgFmtIdent fieldName
    , ", ''), '')"
    ]
-- | SQL sub-select producing an array of @key=eq.value@ (or @key=is.null@)
-- entries built from the first row of 'sourceCTEName'.
--
-- When @pKeys@ is non-empty only those keys are included. Each key is
-- rendered with 'pgFmtLit' so that a name containing a quote or backslash
-- cannot break out of the SQL string literal; for ordinary names the output
-- is unchanged.
locationF :: [Text] -> SqlFragment
locationF pKeys =
  "(" <>
  " WITH s AS (SELECT row_to_json(ss) as r from " <> sourceCTEName <> " as ss limit 1)" <>
  " SELECT array_agg(json_data.key || '=' || coalesce('eq.' || json_data.value, 'is.null'))" <>
  " FROM s, json_each_text(s.r) AS json_data" <>
  keyFilter <> ")"
  where
    keyFilter
      | null pKeys = ""
      | otherwise = " WHERE json_data.key IN (" <> intercalate "," (map pgFmtLit pKeys) <> ")"
-- | LIMIT/OFFSET clause for a range; empty for 'allRange'. A range without a
-- limit renders as @LIMIT ALL@.
limitF :: NonnegRange -> SqlFragment
limitF r
  | r == allRange = ""
  | otherwise = "LIMIT " <> limit <> " OFFSET " <> offset
  where
    limit = maybe "ALL" show (rangeLimit r)
    offset = show (rangeOffset r)
-- | Render a qualified identifier as @"schema"."name"@, or just @"name"@ when
-- the schema is empty.
fromQi :: QualifiedIdentifier -> SqlFragment
fromQi t = schemaPrefix <> pgFmtIdent (qiName t)
  where
    schemaPrefix
      | qiSchema t == "" = ""
      | otherwise = pgFmtIdent (qiSchema t) <> "."
-- | Equality filters that join the two sides of a relation.
--
-- * 'Child' / 'Parent': one filter per pair of local and foreign columns.
-- * 'Many': filters from each side to the link table, using the link
--   columns when present.
-- * 'Root': not defined (evaluates to 'undefined').
--
-- The filters are qualified with the relation table's schema, except for
-- 'Parent' relations where the schema is left empty.
getJoinConditions :: Relation -> [Filter]
getJoinConditions (Relation t cols ft fcs typ lt lc1 lc2) =
  case typ of
    Child -> direct
    Parent -> direct
    Many -> zipWith (toFilter tN ltN) cols (fromMaybe [] lc1)
         ++ zipWith (toFilter ftN ltN) fcs (fromMaybe [] lc2)
    Root -> undefined
  where
    direct = zipWith (toFilter tN ftN) cols fcs
    s
      | typ == Parent = ""
      | otherwise = tableSchema t
    tN = tableName t
    ftN = tableName ft
    ltN = maybe "" tableName lt
    toFilter :: Text -> Text -> Column -> Column -> Filter
    toFilter tb ftb c fc = Filter (colName c, Nothing) (Operation False ("=", target))
      where
        -- The foreign column is re-labelled with the table name to compare against.
        renamed = fc{colTable=(colTable fc){tableName=ftb}}
        target = VForeignKey (QualifiedIdentifier s tb) (ForeignKey renamed)
-- | 'H.statement' for SQL given as 'Text': the text is UTF-8 encoded and the
-- remaining arguments are passed through unchanged.
unicodeStatement :: Text -> HE.Params a -> HD.Result b -> Bool -> H.Query a b
unicodeStatement sql = H.statement (T.encodeUtf8 sql)
-- | Return the empty text when the condition is 'True', otherwise the given
-- text. Note that, despite the name, the text is dropped when the condition
-- holds; callers pass an "is empty" test as the condition.
emptyOnFalse :: Text -> Bool -> Text
emptyOnFalse val cond
  | cond = ""
  | otherwise = val
-- | SQL expression for a JSON value: @null@ for 'JSON.Null', otherwise the
-- unquoted value as a string literal cast to @unknown@.
insertableValue :: JSON.Value -> SqlFragment
insertableValue v = case v of
  JSON.Null -> "null"
  _ -> pgFmtLit (unquoted v) <> "::unknown"
-- | SQL expression for a JSON value as a string literal cast to the given
-- type name. The type name is appended verbatim.
insertableValueWithType :: Text -> JSON.Value -> SqlFragment
insertableValueWithType t v = mconcat [literal, "::", t]
  where
    literal = pgFmtLit (unquoted v)
-- | Render a column of a table; the name @*@ is kept as an unquoted wildcard.
pgFmtColumn :: QualifiedIdentifier -> Text -> SqlFragment
pgFmtColumn table c
  | c == "*" = fromQi table <> ".*"
  | otherwise = fromQi table <> "." <> pgFmtIdent c
-- | Render a field: the qualified column followed by its JSON path, if any.
pgFmtField :: QualifiedIdentifier -> Field -> SqlFragment
pgFmtField table (c, jp) = mconcat [column, jsonPath]
  where
    column = pgFmtColumn table c
    jsonPath = pgFmtJsonPath jp
-- | Render one item of a SELECT list: the field, optionally wrapped in a
-- @CAST@, followed by its alias clause (see 'pgFmtAs').
--
-- NOTE(review): the cast type is spliced in verbatim — confirm it is
-- validated where the request is parsed.
pgFmtSelectItem :: QualifiedIdentifier -> SelectItem -> SqlFragment
pgFmtSelectItem table (f@(_, jp), maybeCast, alias) = rendered <> pgFmtAs jp alias
  where
    field = pgFmtField table f
    rendered = case maybeCast of
      Nothing -> field
      Just cast -> "CAST (" <> field <> " AS " <> cast <> " )"
-- | Render a single filter as an SQL condition on the given table.
--
-- The condition is prefixed with @NOT@ when the operation is negated. The
-- right-hand side depends on the kind of value:
--
-- * text: a literal cast to @unknown@; @*@ becomes @%@ for @like@/@ilike@,
--   @\@\@@ wraps the literal in @to_tsquery@, and @is@/@isnot@ allow the
--   bare keywords @null@, @true@ and @false@ (case-insensitively);
-- * list of texts: a parenthesised list of literals, or a comparison
--   against an empty array when the list is empty or holds a single empty
--   string;
-- * foreign key: a comparison between the field and the foreign column.
--
-- Operators are translated through the 'operators' table, defaulting to @=@.
pgFmtFilter :: QualifiedIdentifier -> Filter -> SqlFragment
pgFmtFilter table (Filter fld (Operation hasNot_ ex)) = negation <> " " <> condition
  where
    negation = if hasNot_ then "NOT" else ""

    condition = case ex of
      (op, VText val) -> fieldOp op <> " " <> textOperand op val
      (op, VTextL vals) -> inList op vals
      (op, VForeignKey fQi (ForeignKey Column{colTable=Table{tableName=fTableName}, colName=fColName})) ->
        pgFmtField fQi fld <> " " <> sqlOperator op <> " "
          <> pgFmtColumn (removeSourceCTESchema (qiSchema fQi) fTableName) fColName

    fieldOp op = pgFmtField table fld <> " " <> sqlOperator op
    sqlOperator o = HM.lookupDefault "=" o operators

    textOperand :: Operator -> Text -> SqlFragment
    textOperand op val
      | op == "like" || op == "ilike" = unknownLiteral (T.map star val)
      | op == "@@" = "to_tsquery(" <> unknownLiteral val <> ") "
      | op == "is" || op == "isnot" = whiteList val
      | otherwise = unknownLiteral val

    star c = if c == '*' then '%' else c

    unknownLiteral :: Text -> SqlFragment
    unknownLiteral v = pgFmtLit v <> "::unknown "

    -- Known keywords are emitted bare (lower-cased); anything else is quoted.
    whiteList :: Text -> SqlFragment
    whiteList v = fromMaybe (unknownLiteral v) (find (== toLower v) ["null", "true", "false"])

    inList :: Operator -> [Text] -> SqlFragment
    inList op vals = case vals of
      [] -> emptyIn
      [v] | T.null v -> emptyIn
      _ -> fieldOp op <> "(" <> intercalate ", " (map unknownLiteral vals) <> ") "
      where
        emptyIn =
          (if "not" `isInfixOf` op then "NOT " else "")
            <> pgFmtField table fld <> " = any('{}') "
-- | Render a logic tree: leaves are rendered with 'pgFmtFilter', inner nodes
-- as a parenthesised pair of sub-trees joined by the shown operator and
-- prefixed with @NOT@ when negated.
pgFmtLogicTree :: QualifiedIdentifier -> LogicTree -> SqlFragment
pgFmtLogicTree qi tree = case tree of
  Stmnt flt -> pgFmtFilter qi flt
  Expr hasNot_ op lt rt ->
    mconcat
      [ if hasNot_ then "NOT" else ""
      , " ("
      , pgFmtLogicTree qi lt
      , " "
      , show op
      , " "
      , pgFmtLogicTree qi rt
      , ")"
      ]
-- | Render a JSON path: every key but the last is accessed with @->@, the
-- last one with @->>@. An absent or empty path renders as nothing.
pgFmtJsonPath :: Maybe JsonPath -> SqlFragment
pgFmtJsonPath = maybe "" go
  where
    go :: JsonPath -> SqlFragment
    go [] = ""
    go [x] = "->>" <> pgFmtLit x
    go (x:xs) = "->" <> pgFmtLit x <> go xs
-- | Alias clause for a select item: the explicit alias when given, otherwise
-- the last key of the JSON path, otherwise nothing.
pgFmtAs :: Maybe JsonPath -> Maybe Alias -> SqlFragment
pgFmtAs jp alias = maybe "" asClause chosen
  where
    asClause name = " AS " <> pgFmtIdent name
    chosen = case alias of
      Just _ -> alias
      Nothing -> jp >>= lastMay
-- | @set local@ statement assigning a value to the setting named by the
-- prefix followed by the key; the name is quoted as an identifier and the
-- value as a literal.
pgFmtEnvVar :: Text -> (Text, Text) -> SqlFragment
pgFmtEnvVar prefix (k, v) =
  mconcat
    [ "set local "
    , pgFmtIdent (prefix <> k)
    , " = "
    , pgFmtLit v
    , ";"
    ]
-- | Keep the text up to, but not including, the first NUL character.
trimNullChars :: Text -> Text
trimNullChars txt = T.takeWhile notNul txt
  where
    notNul c = c /= '\x0'