{-|
Module: Squeal.PostgreSQL.Session.Encode
Description: encoding of statement parameters
Copyright: (c) Eitan Chatav, 2019
Maintainer: eitan@morphism.tech
Stability: experimental

encoding of statement parameters
-}

{-# LANGUAGE
    AllowAmbiguousTypes
  , ConstraintKinds
  , DataKinds
  , DefaultSignatures
  , FlexibleContexts
  , FlexibleInstances
  , LambdaCase
  , MultiParamTypeClasses
  , PolyKinds
  , RankNTypes
  , ScopedTypeVariables
  , TypeApplications
  , TypeFamilies
  , TypeOperators
  , UndecidableInstances
#-}

module Squeal.PostgreSQL.Session.Encode
  ( -- * Encode Parameters
    EncodeParams (..)
  , genericParams
  , nilParams
  , (.*)
  , (*.)
  , aParam
  , appendParams
    -- * Encoding Classes
  , ToPG (..)
  , ToParam (..)
  , ToField (..)
  , ToArray (..)
  ) where

import ByteString.StrictBuilder
import Control.Monad
import Control.Monad.Reader
import Data.Bits
import Data.ByteString as Strict (ByteString)
import Data.ByteString.Lazy as Lazy (ByteString)
import Data.Functor.Contravariant
import Data.Int (Int16, Int32, Int64)
import Data.Kind
import Data.Scientific (Scientific)
import Data.Text as Strict (Text)
import Data.Text.Lazy as Lazy (Text)
import Data.Time (Day, TimeOfDay, TimeZone, LocalTime, UTCTime, DiffTime)
import Data.UUID.Types (UUID)
import Data.Vector (Vector)
import Data.Word (Word32)
import Foreign.C.Types (CUInt(CUInt))
import GHC.TypeLits
import Network.IP.Addr (NetAddr, IP)
import PostgreSQL.Binary.Encoding

import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as Lazy.ByteString
import qualified Data.Text as Strict.Text
import qualified Database.PostgreSQL.LibPQ as LibPQ
import qualified Generics.SOP as SOP
import qualified Generics.SOP.Record as SOP

import Squeal.PostgreSQL.Expression.Range
import Squeal.PostgreSQL.Session.Oid
import Squeal.PostgreSQL.Type
import Squeal.PostgreSQL.Type.Alias
import Squeal.PostgreSQL.Type.List
import Squeal.PostgreSQL.Type.PG
import Squeal.PostgreSQL.Type.Schema

-- $setup
-- >>> import Squeal.PostgreSQL (connectdb, finish)

-- | A `ToPG` constraint gives an encoding of a Haskell `Type`
-- into the binary format of a PostgreSQL `PGType`.
class IsPG x => ToPG (db :: SchemasType) (x :: Type) where
  -- | Encode a value; the encoder runs in a reader over the connection.
  --
  -- >>> :set -XTypeApplications -XDataKinds
  -- >>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
  -- >>> runReaderT (toPG @'[] False) conn
  -- "\NUL"
  --
  -- >>> runReaderT (toPG @'[] (0 :: Int16)) conn
  -- "\NUL\NUL"
  --
  -- >>> runReaderT (toPG @'[] (0 :: Int32)) conn
  -- "\NUL\NUL\NUL\NUL"
  --
  -- >>> :set -XMultiParamTypeClasses -XGeneralizedNewtypeDeriving
  -- >>> newtype UserId = UserId { getUserId :: Int64 } deriving newtype (IsPG, ToPG db)
  -- >>> runReaderT (toPG @'[] (UserId 0)) conn
  -- "\NUL\NUL\NUL\NUL\NUL\NUL\NUL\NUL"
  --
  -- >>> finish conn
  toPG :: x -> ReaderT (SOP.K LibPQ.Connection db) IO Encoding

instance ToPG db Bool where
  toPG = pure . bool

instance ToPG db Int16 where
  toPG = pure . int2_int16

instance ToPG db Int32 where
  toPG = pure . int4_int32

instance ToPG db Int64 where
  toPG = pure . int8_int64

instance ToPG db Oid where
  toPG = pure . int4_word32 . getOid

instance ToPG db Float where
  toPG = pure . float4

instance ToPG db Double where
  toPG = pure . float8

instance ToPG db Scientific where
  toPG = pure . numeric

instance ToPG db Money where
  toPG = pure . int8_int64 . cents

instance ToPG db UUID where
  toPG = pure . uuid

instance ToPG db (NetAddr IP) where
  toPG = pure . inet

instance ToPG db Char where
  toPG = pure . char_utf8

instance ToPG db Strict.Text where
  toPG = pure . text_strict

instance ToPG db Lazy.Text where
  toPG = pure . text_lazy

-- A `String` is packed into strict `Strict.Text` before encoding.
instance ToPG db String where
  toPG = pure . text_strict . Strict.Text.pack

instance ToPG db Strict.ByteString where
  toPG = pure . bytea_strict

instance ToPG db Lazy.ByteString where
  toPG = pure . bytea_lazy

instance ToPG db Day where
  toPG = pure . date

instance ToPG db TimeOfDay where
  toPG = pure . time_int

instance ToPG db (TimeOfDay, TimeZone) where
  toPG = pure . timetz_int

instance ToPG db LocalTime where
  toPG = pure . timestamp_int

instance ToPG db UTCTime where
  toPG = pure . timestamptz_int

instance ToPG db DiffTime where
  toPG = pure . interval_int

instance ToPG db Aeson.Value where
  toPG = pure . json_ast

-- The wrapped value is serialized with Aeson and the resulting bytes
-- are handed to the @json@ encoder.
instance Aeson.ToJSON x => ToPG db (Json x) where
  toPG
    = pure
    . json_bytes
    . Lazy.ByteString.toStrict
    . Aeson.encode
    . getJson

-- Same as the `Json` instance, but through the @jsonb@ encoder.
instance Aeson.ToJSON x => ToPG db (Jsonb x) where
  toPG
    = pure
    . jsonb_bytes
    . Lazy.ByteString.toStrict
    . Aeson.encode
    . getJsonb

-- A list is encoded as a one-dimensional array whose single extent
-- is the length of the list.
instance (NullPG x ~ ty, ToArray db '[] ty x, OidOfNull db ty)
  => ToPG db (VarArray [x]) where
    toPG (VarArray elements) = do
      elemOid <- oidOfNull @db @ty
      body <- dimArray foldM (arrayPayload @db @'[] @ty @x) elements
      pure $ encodeArray 1 hasNulls elemOid extents body
      where
        extents = [fromIntegral (length elements)]
        hasNulls = arrayNulls @db @'[] @ty @x

-- A `Vector` is encoded exactly like a list.
instance (NullPG x ~ ty, ToArray db '[] ty x, OidOfNull db ty)
  => ToPG db (VarArray (Vector x)) where
    toPG (VarArray elements) = do
      elemOid <- oidOfNull @db @ty
      body <- dimArray foldM (arrayPayload @db @'[] @ty @x) elements
      pure $ encodeArray 1 hasNulls elemOid extents body
      where
        extents = [fromIntegral (length elements)]
        hasNulls = arrayNulls @db @'[] @ty @x

-- The extents and the number of dimensions of a fixed-length array
-- come from the type-level @dims@ via `arrayDims`.
instance (ToArray db dims ty x, OidOfNull db ty)
  => ToPG db (FixArray x) where
    toPG (FixArray arr) = do
      elemOid <- oidOfNull @db @ty
      body <- arrayPayload @db @dims @ty arr
      pure $ encodeArray rank hasNulls elemOid extents body
      where
        extents = arrayDims @db @dims @ty @x
        hasNulls = arrayNulls @db @dims @ty @x
        rank = fromIntegral (length extents)

-- An enumerated value is encoded as text: the name of its constructor.
instance
  ( SOP.IsEnumType x
  , SOP.HasDatatypeInfo x
  , LabelsPG x ~ labels
  ) => ToPG db (Enumerated x) where
    toPG
      = pure
      . text_strict
      . Strict.Text.pack
      . constructorLabel allConstructors
      . SOP.from
      . getEnumerated
      where
        allConstructors =
          SOP.constructorInfo (SOP.datatypeInfo (SOP.Proxy @x))
        -- Walk the constructor metadata in step with the sum
        -- representation until the chosen constructor is reached.
        constructorLabel
          :: NP SOP.ConstructorInfo xss
          -> SOP.SOP SOP.I xss
          -> String
        constructorLabel Nil _ = ""
        constructorLabel (info :* _) (SOP.SOP (SOP.Z _)) =
          SOP.constructorName info
        constructorLabel (_ :* rest) (SOP.SOP (SOP.S choice)) =
          constructorLabel rest (SOP.SOP choice)

-- A composite is encoded as the number of fields followed by, for each
-- field, its oid and then either `null4` or the `sized` field encoding.
instance
  ( SOP.SListI fields
  , SOP.IsRecord x xs
  , SOP.AllZip (ToField db) fields xs
  , SOP.All (OidOfField db) fields
  , RowPG x ~ fields
  ) => ToPG db (Composite x) where
    toPG (Composite record) = do
      encodedFields :: NP (SOP.K (Maybe Encoding)) fields <- hctransverse
        (SOP.Proxy @(ToField db)) (toField @db) (SOP.toRecord record)
      fieldsPayload <- hcfoldMapM
        (SOP.Proxy @(OidOfField db)) encodeField encodedFields
      pure $ fieldCount <> fieldsPayload
      where
        fieldCount
          = int4_int32
          . fromIntegral
          . SOP.lengthSList
          $ SOP.Proxy @xs
        encodeField
          :: OidOfField db field
          => SOP.K (Maybe Encoding) field
          -> ReaderT (SOP.K LibPQ.Connection db) IO Encoding
        encodeField (SOP.K mEncoding :: SOP.K (Maybe Encoding) field) = do
          fieldOid <- getOid <$> oidOfField @db @field
          pure $ int4_word32 fieldOid <> maybe null4 sized mEncoding

-- A range is encoded as one flag byte followed by the `sized` encodings
-- of its finite bounds, lower bound first. Flag bits set below:
-- bit 0 for `Empty`, bit 1 / bit 2 for a `Closed` lower / upper bound,
-- bit 3 / bit 4 for an `Infinite` lower / upper bound.
instance ToPG db x => ToPG db (Range x) where
  toPG range = do
    bounds <- case range of
      Empty -> pure mempty
      NonEmpty lower upper ->
        (<>) <$> encodeBound lower <*> encodeBound upper
    pure $ word8 (rangeFlags range 0) <> bounds
    where
      encodeBound = \case
        Infinite -> pure mempty
        Closed endpoint -> sized <$> toPG @db endpoint
        Open endpoint -> sized <$> toPG @db endpoint
      rangeFlags = \case
        Empty -> (`setBit` 0)
        NonEmpty lower upper -> lowerFlags lower . upperFlags upper
      lowerFlags = \case
        Infinite -> (`setBit` 3)
        Closed _ -> (`setBit` 1)
        Open _ -> id
      upperFlags = \case
        Infinite -> (`setBit` 4)
        Closed _ -> (`setBit` 2)
        Open _ -> id

-- | A `ToParam` constraint gives an encoding of a Haskell `Type`
-- into the binary format of a PostgreSQL `NullType`.
-- You should not define instances for `ToParam`,
-- just use the provided instances.
class ToParam (db :: SchemasType) (ty :: NullType) (x :: Type) where
  toParam :: x -> ReaderT (SOP.K LibPQ.Connection db) IO (Maybe Encoding)

-- A value of a `NotNull` type always produces `Just` an encoding.
instance (ToPG db x, pg ~ PG x) => ToParam db ('NotNull pg) x where
  toParam value = Just <$> toPG @db value

-- `Nothing` stays `Nothing`; `Just` values are encoded.
instance (ToPG db x, pg ~ PG x) => ToParam db ('Null pg) (Maybe x) where
  toParam = traverse (toPG @db)

-- | A `ToField` constraint lifts the `ToParam` encoding
-- to an encoding of a @(Symbol, Type)@ to a @(Symbol, NullType)@,
-- encoding `Null`s to `Maybe`s. You should not define instances for
-- `ToField`, just use the provided instances.
class ToField
  (db :: SchemasType)
  (field :: (Symbol, NullType))
  (x :: (Symbol, Type)) where
  toField
    :: SOP.P x
    -> ReaderT (SOP.K LibPQ.Connection db) IO (SOP.K (Maybe Encoding) field)

instance (fld0 ~ fld1, ToParam db ty x)
  => ToField db (fld0 ::: ty) (fld1 ::: x) where
    toField (SOP.P value) = fmap SOP.K (toParam @db @ty value)

-- | A `ToArray` constraint gives an encoding of a Haskell `Type`
-- into the binary format of a PostgreSQL fixed-length array.
-- You should not define instances for
-- `ToArray`, just use the provided instances.
class ToArray
  (db :: SchemasType)
  (dims :: [Nat])
  (ty :: NullType)
  (x :: Type) where
  -- | Encoding of the array elements, without any array header.
  arrayPayload :: x -> ReaderT (SOP.K LibPQ.Connection db) IO Encoding
  -- | Extent of each dimension, outermost first.
  arrayDims :: [Int32]
  -- | Whether the element type is nullable.
  arrayNulls :: Bool

-- Base case: a single non-null element, prefixed with its length.
instance (ToPG db x, pg ~ PG x)
  => ToArray db '[] ('NotNull pg) x where
    arrayPayload value = sized <$> toPG @db @x value
    arrayDims = []
    arrayNulls = False

-- Base case: a single nullable element; `Nothing` is written as `null4`.
instance (ToPG db x, pg ~ PG x)
  => ToArray db '[] ('Null pg) (Maybe x) where
    arrayPayload = maybe (pure null4) (fmap sized . toPG @db @x)
    arrayDims = []
    arrayNulls = True

-- Inductive case: a homogeneous product of @dim@ components adds one
-- dimension on top of the component's own dimensions.
instance
  ( SOP.IsProductType tuple xs
  , Length xs ~ dim
  , SOP.All ((~) x) xs
  , ToArray db dims ty x
  , KnownNat dim
  ) => ToArray db (dim ': dims) ty tuple where
    arrayPayload
      = dimArray foldlNP (arrayPayload @db @dims @ty @x)
      . SOP.unZ
      . SOP.unSOP
      . SOP.from
    arrayDims
      = fromIntegral (natVal (SOP.Proxy @dim))
      : arrayDims @db @dims @ty @x
    arrayNulls = arrayNulls @db @dims @ty @x

-- Monadic left fold over a product whose components all have type @x@.
foldlNP
  :: (SOP.All ((~) x) xs, Monad m)
  => (z -> x -> m z) -> z -> NP SOP.I xs -> m z
foldlNP _ acc Nil = pure acc
foldlNP step acc (SOP.I x :* rest) =
  step acc x >>= \acc' -> foldlNP step acc' rest

{- | `EncodeParams` describes an encoding of a Haskell `Type`
into a list of parameter `NullType`s.

>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[]
    '[ 'NotNull 'PGint2, 'NotNull ('PGchar 1), 'NotNull 'PGtext]
    (Int16, (Char, String))
  encode = fst .* fst.snd *. snd.snd
in runReaderT (runEncodeParams encode (1,('a',"foo"))) conn
:}
K (Just "\NUL\SOH") :* K (Just "a") :* K (Just "foo") :* Nil

>>> finish conn
-}
newtype EncodeParams
  (db :: SchemasType)
  (tys :: [NullType])
  (x :: Type) = EncodeParams
  { runEncodeParams
      :: x
      -> ReaderT (SOP.K LibPQ.Connection db) IO (NP (SOP.K (Maybe Encoding)) tys)
  }

-- Pre-compose the encoder with a function on its input.
instance Contravariant (EncodeParams db tys) where
  contramap f encode = EncodeParams (runEncodeParams encode . f)

{- | Parameter encoding for `SOP.Generic` tuples and records.

>>> import qualified GHC.Generics as GHC
>>> import qualified Generics.SOP as SOP
>>> data Two = Two Int16 String deriving (GHC.Generic, SOP.Generic)
>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[] '[ 'NotNull 'PGint2, 'NotNull 'PGtext] Two
  encode = genericParams
in runReaderT (runEncodeParams encode (Two 2 "two")) conn
:}
K (Just "\NUL\STX") :* K (Just "two") :* Nil

>>> :{
let
  encode :: EncodeParams '[] '[ 'NotNull 'PGint2, 'NotNull 'PGtext] (Int16, String)
  encode = genericParams
in runReaderT (runEncodeParams encode (2, "two")) conn
:}
K (Just "\NUL\STX") :* K (Just "two") :* Nil

>>> finish conn
-}
genericParams :: forall db params x xs.
  ( SOP.IsProductType x xs
  , SOP.AllZip (ToParam db) params xs
  ) => EncodeParams db params x
genericParams = EncodeParams $ \value ->
  hctransverse (SOP.Proxy @(ToParam db)) encodeComponent
    (SOP.unZ (SOP.unSOP (SOP.from value)))
  where
    -- Encode one component of the product as one parameter.
    encodeComponent
      :: forall ty y. ToParam db ty y
      => SOP.I y
      -> ReaderT (SOP.K LibPQ.Connection db) IO (SOP.K (Maybe Encoding) ty)
    encodeComponent (SOP.I component) =
      SOP.K <$> toParam @db @ty component

-- | Encode 0 parameters.
nilParams :: EncodeParams db '[] x
nilParams = EncodeParams (const (pure Nil))

{- | Cons a parameter encoding.

>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[]
    '[ 'Null 'PGint4, 'NotNull 'PGtext]
    (Maybe Int32, String)
  encode = fst .* snd .* nilParams
in runReaderT (runEncodeParams encode (Nothing, "foo")) conn
:}
K Nothing :* K (Just "foo") :* Nil

>>> finish conn
-}
(.*)
  :: forall db x0 ty x tys. (ToParam db ty x0)
  => (x -> x0) -- ^ head
  -> EncodeParams db tys x -- ^ tail
  -> EncodeParams db (ty ': tys) x
project .* rest = EncodeParams $ \input -> do
  -- The head parameter is encoded before the tail parameters.
  headParam <- toParam @db @ty (project input)
  tailParams <- runEncodeParams rest input
  pure (SOP.K headParam :* tailParams)
infixr 5 .*

{- | End a parameter encoding.

>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[]
    '[ 'Null 'PGint4, 'NotNull 'PGtext, 'NotNull ('PGchar 1)]
    (Maybe Int32, String, Char)
  encode = (\(x,_,_) -> x) .* (\(_,y,_) -> y) *. (\(_,_,z) -> z)
in runReaderT (runEncodeParams encode (Nothing, "foo", 'z')) conn
:}
K Nothing :* K (Just "foo") :* K (Just "z") :* Nil

>>> finish conn
-}
(*.)
  :: forall db x x0 ty0 x1 ty1
   . (ToParam db ty0 x0, ToParam db ty1 x1)
  => (x -> x0) -- ^ second to last
  -> (x -> x1) -- ^ last
  -> EncodeParams db '[ty0, ty1] x
secondToLast *. final = secondToLast .* (final .* nilParams)
infixl 8 *.

{- | Encode 1 parameter.

>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[] '[ 'NotNull 'PGint4] Int32
  encode = aParam
in runReaderT (runEncodeParams encode 1776) conn
:}
K (Just "\NUL\NUL\ACK\240") :* Nil

>>> finish conn
-}
aParam
  :: forall db x. ToParam db (NullPG x) x
  => EncodeParams db '[NullPG x] x
aParam = EncodeParams $ \value -> do
  param <- toParam @db @(NullPG x) value
  pure (SOP.K param :* Nil)

{- | Append parameter encodings.

>>> conn <- connectdb @'[] "host=localhost port=5432 dbname=exampledb"
>>> :{
let
  encode :: EncodeParams '[]
    '[ 'NotNull 'PGint4, 'NotNull 'PGint2]
    (Int32, Int16)
  encode = contramap fst aParam `appendParams` contramap snd aParam
in runReaderT (runEncodeParams encode (1776, 2)) conn
:}
K (Just "\NUL\NUL\ACK\240") :* K (Just "\NUL\STX") :* Nil

>>> finish conn
-}
appendParams
  :: EncodeParams db params0 x -- ^ left
  -> EncodeParams db params1 x -- ^ right
  -> EncodeParams db (Join params0 params1) x
appendParams encode0 encode1 = EncodeParams $ \input -> do
  -- The right encoder runs first and is passed first to `also`;
  -- judging by the result type @Join params0 params1@, `also` places
  -- its second argument in front (defined elsewhere, confirm there).
  rightParams <- runEncodeParams encode1 input
  leftParams <- runEncodeParams encode0 input
  pure (also rightParams leftParams)

-- Unwrap a LibPQ oid to its underlying 32-bit word.
getOid :: LibPQ.Oid -> Word32
getOid (LibPQ.Oid (CUInt word)) = word

-- Assemble an array encoding: number of dimensions, null flag,
-- element oid, then for each dimension its extent followed by the
-- constant 1 (presumably the dimension's lower bound in the PostgreSQL
-- array wire format, TODO confirm), and finally the element payload.
encodeArray :: Int32 -> Bool -> LibPQ.Oid -> [Int32] -> Encoding -> Encoding
encodeArray ndim nulls oid dimensions payload =
  int4_int32 ndim
    <> nullFlag
    <> int4_word32 (getOid oid)
    <> foldMap dimensionHeader dimensions
    <> payload
  where
    nullFlag = if nulls then true4 else false4
    dimensionHeader extent = int4_int32 extent <> true4

-- Encode every element reached by the given fold and concatenate the
-- encodings in fold order, starting from the empty encoding.
dimArray
  :: Functor m
  => (forall b. (b -> a -> m b) -> b -> c -> m b)
  -> (a -> m Encoding) -> c -> m Encoding
dimArray folder encodeElement = folder appendElement mempty
  where
    appendElement acc element = fmap (acc <>) (encodeElement element)

-- Four-byte constants: -1 (used where a length would go for a null
-- value), 1 and 0.
null4 :: Encoding
null4 = int4_int32 (-1)

true4 :: Encoding
true4 = int4_word32 1

false4 :: Encoding
false4 = int4_word32 0

-- Prefix an encoding with its length in bytes as a 4-byte integer.
sized :: Encoding -> Encoding
sized encoding =
  int4_int32 (fromIntegral (builderLength encoding)) <> encoding

-- Constrained applicative traversal of an n-ary product that may also
-- change the index list, pairing indices up through `SOP.AllZip`.
hctransverse
  :: (SOP.AllZip c ys xs, Applicative m)
  => SOP.Proxy c
  -> (forall y x. c y x => f x -> m (g y))
  -> NP f xs -> m (NP g ys)
hctransverse _ _ Nil = pure Nil
hctransverse c f (x :* xs) = (:*) <$> f x <*> hctransverse c f xs

-- Constrained applicative fold of an n-ary product into a monoid,
-- combining results left to right.
hcfoldMapM
  :: (Monoid r, Applicative m, SOP.All c xs)
  => SOP.Proxy c
  -> (forall x. c x => f x -> m r)
  -> NP f xs -> m r
hcfoldMapM _ _ Nil = pure mempty
hcfoldMapM c f (x :* xs) = (<>) <$> f x <*> hcfoldMapM c f xs