{-# LANGUAGE
AllowAmbiguousTypes
, DataKinds
, DerivingStrategies
, FlexibleContexts
, FlexibleInstances
, FunctionalDependencies
, GeneralizedNewtypeDeriving
, LambdaCase
, MultiParamTypeClasses
, OverloadedStrings
, PolyKinds
, ScopedTypeVariables
, TypeApplications
, TypeFamilies
, TypeOperators
, UndecidableInstances
#-}
module Squeal.PostgreSQL.Session.Decode
(
FromPG (..)
, devalue
, rowValue
, enumValue
, DecodeRow (..)
, decodeRow
, runDecodeRow
, genericRow
, appendRows
, consRow
, FromValue (..)
, FromField (..)
, FromArray (..)
, StateT (..)
, ExceptT (..)
) where
import BinaryParser
import Control.Applicative
import Control.Arrow
import Control.Monad
import Control.Monad.Fail
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans.Maybe
import Data.Bits
import Data.Int (Int16, Int32, Int64)
import Data.Kind
import Data.Scientific (Scientific)
import Data.String (fromString)
import Data.Text (Text)
import Data.Time (Day, TimeOfDay, TimeZone, LocalTime, UTCTime, DiffTime)
import Data.UUID.Types (UUID)
import Data.Vector (Vector)
import Database.PostgreSQL.LibPQ (Oid(Oid))
import GHC.OverloadedLabels
import GHC.TypeLits
import Network.IP.Addr (NetAddr, IP)
import PostgreSQL.Binary.Decoding hiding (Composite)
import Unsafe.Coerce
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import qualified Data.ByteString as Strict (ByteString)
import qualified Data.Text.Lazy as Lazy (Text)
import qualified Data.Text as Strict (Text)
import qualified Data.Text as Strict.Text
import qualified Data.Vector as Vector
import qualified Generics.SOP as SOP
import qualified Generics.SOP.Record as SOP
import Squeal.PostgreSQL.Expression.Range
import Squeal.PostgreSQL.Type
import Squeal.PostgreSQL.Type.Alias
import Squeal.PostgreSQL.Type.List
import Squeal.PostgreSQL.Type.PG
import Squeal.PostgreSQL.Type.Schema
-- | View a @postgresql-binary@ 'Value' decoder as the underlying
-- state-and-error parser, so that it can be used to define 'fromPG'.
--
-- NOTE(review): this is only sound if 'Value' has the same runtime
-- representation as @StateT Strict.ByteString (Except Strict.Text)@
-- (presumably via newtype wrappers whose constructors are not exported,
-- which is why 'unsafeCoerce' is used instead of 'coerce'). Re-check this
-- whenever the @postgresql-binary@ / @binary-parser@ dependencies change.
devalue :: Value x -> StateT Strict.ByteString (Except Strict.Text) x
devalue parser = unsafeCoerce parser
-- | Inverse of 'devalue': turn a state-and-error parser (such as a
-- 'fromPG' decoder) back into a @postgresql-binary@ 'Value' so it can be
-- combined with that library's decoders.
--
-- NOTE(review): relies on the same representational equality as
-- 'devalue'; confirm it still holds when dependencies are upgraded.
revalue :: StateT Strict.ByteString (Except Strict.Text) x -> Value x
revalue parser = unsafeCoerce parser
-- | Lift a 'DecodeRow' for the fields of a composite type into a decoder
-- for a whole composite value, suitable for defining 'fromPG'.
--
-- The binary composite layout consumed here is: a 4-byte field count,
-- then for every field a 4-byte type 'Oid' (skipped, not checked), a
-- 4-byte signed length where @-1@ denotes @NULL@, and that many bytes of
-- field data. Each field's bytes ('Nothing' for @NULL@) are handed to the
-- 'DecodeRow'.
--
-- Decoding fails with a descriptive message when the field count sent by
-- the server differs from the length of @row@ (the Haskell row type does
-- not match the database composite type), or when a field length is
-- negative but not @-1@ (malformed input).
rowValue
  :: forall y row. (PG y ~ 'PGcomposite row, SOP.SListI row)
  => DecodeRow row y
  -> StateT Strict.ByteString (Except Strict.Text) y
rowValue decoder = devalue (fn (runDecodeRow decoder <=< splitFields))
  where
    expectedCount :: Int
    expectedCount = SOP.lengthSList (SOP.Proxy @row)

    -- Split the composite into the raw bytes of each field.
    splitFields = valueParser $ do
      actualCount :: Int32 <- sized 4 int
      -- Positional decoding against a composite of a different arity
      -- would silently misalign or ignore fields, so reject it up front.
      when (fromIntegral actualCount /= expectedCount) $
        throwError $ Strict.Text.pack $ concat
          [ "rowValue: composite value has "
          , show actualCount
          , " fields, but the row type expects "
          , show expectedCount
          , "."
          ]
      SOP.hsequence' $ SOP.hpure $ SOP.Comp $ do
        unitOfSize 4 -- the field's type Oid
        len :: Int32 <- sized 4 int
        case compare len (-1) of
          EQ -> return (SOP.K Nothing)
          -- Any other negative length is malformed; never pass it on
          -- to 'bytesOfSize'.
          LT -> throwError $ Strict.Text.pack $
            "rowValue: invalid negative field length " ++ show len ++ "."
          GT -> SOP.K . Just <$> bytesOfSize (fromIntegral len)
-- | A 'FromPG' instance gives a decoder from the PostgreSQL binary format
-- for the Haskell type @y@.
class IsPG y => FromPG y where
  -- | Parse a @y@ from the binary bytes of its PostgreSQL counterpart,
  -- failing with a 'Strict.Text' message.
  fromPG
    :: StateT Strict.ByteString (Except Strict.Text) y
-- | Decoded with @bool@.
instance FromPG Bool where
  fromPG = devalue bool

-- | Decoded with @int@.
instance FromPG Int16 where
  fromPG = devalue int

-- | Decoded with @int@.
instance FromPG Int32 where
  fromPG = devalue int

-- | Decoded with @int@.
instance FromPG Int64 where
  fromPG = devalue int

-- | Decoded with @int@, then wrapped in 'Oid'.
instance FromPG Oid where
  fromPG = devalue (fmap Oid int)

-- | Decoded with @float4@.
instance FromPG Float where
  fromPG = devalue float4

-- | Decoded with @float8@.
instance FromPG Double where
  fromPG = devalue float8

-- | Decoded with @numeric@.
instance FromPG Scientific where
  fromPG = devalue numeric

-- | Decoded with @int@, then wrapped in 'Money'.
instance FromPG Money where
  fromPG = devalue (fmap Money int)

-- | Decoded with @uuid@.
instance FromPG UUID where
  fromPG = devalue uuid

-- | Decoded with @inet@.
instance FromPG (NetAddr IP) where
  fromPG = devalue inet

-- | Decoded with @char@.
instance FromPG Char where
  fromPG = devalue char

-- | Decoded with @text_strict@.
instance FromPG Strict.Text where
  fromPG = devalue text_strict

-- | Decoded with @text_lazy@.
instance FromPG Lazy.Text where
  fromPG = devalue text_lazy

-- | Decoded with @text_strict@, then unpacked to a 'String'.
instance FromPG String where
  fromPG = devalue (fmap Strict.Text.unpack text_strict)

-- | Decoded with @bytea_strict@.
instance FromPG Strict.ByteString where
  fromPG = devalue bytea_strict

-- | Decoded with @bytea_lazy@.
instance FromPG Lazy.ByteString where
  fromPG = devalue bytea_lazy
-- | Decodes @text@ and validates it with 'varChar'. When 'varChar' rejects
-- the text, decoding fails with a message reporting the type-level length
-- @n@ and the actual length of the text.
instance KnownNat n => FromPG (VarChar n) where
  fromPG = devalue $ do
      txt <- text_strict
      maybe (throwError (lengthError txt)) pure (varChar txt)
    where
      lengthError txt = Strict.Text.pack $ concat
        [ "Source for VarChar has wrong length"
        , "; expected length "
        , show (natVal (SOP.Proxy @n))
        , ", actual length "
        , show (Strict.Text.length txt)
        , "."
        ]
-- | Decodes @text@ and validates it with 'fixChar'. When 'fixChar' rejects
-- the text, decoding fails with a message reporting the type-level length
-- @n@ and the actual length of the text.
instance KnownNat n => FromPG (FixChar n) where
  fromPG = devalue $ do
      txt <- text_strict
      maybe (throwError (lengthError txt)) pure (fixChar txt)
    where
      lengthError txt = Strict.Text.pack $ concat
        [ "Source for FixChar has wrong length"
        , "; expected length "
        , show (natVal (SOP.Proxy @n))
        , ", actual length "
        , show (Strict.Text.length txt)
        , "."
        ]
-- | Decoded with @date@.
instance FromPG Day where
  fromPG = devalue date

-- | Decoded with @time_int@.
-- NOTE(review): the @_int@ decoders presumably assume the server uses the
-- integer datetime representation — confirm for the targeted servers.
instance FromPG TimeOfDay where
  fromPG = devalue time_int

-- | Decoded with @timetz_int@.
instance FromPG (TimeOfDay, TimeZone) where
  fromPG = devalue timetz_int

-- | Decoded with @timestamp_int@.
instance FromPG LocalTime where
  fromPG = devalue timestamp_int

-- | Decoded with @timestamptz_int@.
instance FromPG UTCTime where
  fromPG = devalue timestamptz_int

-- | Decoded with @interval_int@.
instance FromPG DiffTime where
  fromPG = devalue interval_int

-- | Decoded with @json_ast@.
instance FromPG Aeson.Value where
  fromPG = devalue json_ast

-- | Decoded with @json_bytes@, parsing the payload with
-- 'Aeson.eitherDecodeStrict'. Aeson's 'String' error is packed into
-- 'Strict.Text'.
instance Aeson.FromJSON x => FromPG (Json x) where
  fromPG = devalue (fmap Json (json_bytes parseJson))
    where
      parseJson = left Strict.Text.pack . Aeson.eitherDecodeStrict

-- | Decoded with @jsonb_bytes@, parsing the payload with
-- 'Aeson.eitherDecodeStrict'. Aeson's 'String' error is packed into
-- 'Strict.Text'.
instance Aeson.FromJSON x => FromPG (Jsonb x) where
  fromPG = devalue (fmap Jsonb (jsonb_bytes parseJson))
    where
      parseJson = left Strict.Text.pack . Aeson.eitherDecodeStrict
-- | Decodes an array whose outer dimension is read with 'dimensionArray'.
-- The element count supplied by 'dimensionArray' is used to replicate the
-- element decoder @fromArray \@'[] \@(NullPG y)@ into a 'Vector'.
instance (FromArray '[] ty y, ty ~ NullPG y)
  => FromPG (VarArray (Vector y)) where
    fromPG =
      devalue (array (dimensionArray replicateVector (fromArray @'[] @(NullPG y))))
      where
        replicateVector :: Monad m => Int -> m y -> m (VarArray (Vector y))
        replicateVector count element =
          fmap VarArray (Vector.replicateM count element)
-- | Decodes an array whose outer dimension is read with 'dimensionArray'.
-- The element count supplied by 'dimensionArray' is used to replicate the
-- element decoder @fromArray \@'[] \@(NullPG y)@ into a list.
instance (FromArray '[] ty y, ty ~ NullPG y)
  => FromPG (VarArray [y]) where
    fromPG =
      devalue (array (dimensionArray replicateList (fromArray @'[] @(NullPG y))))
      where
        replicateList :: Monad m => Int -> m y -> m (VarArray [y])
        replicateList count element =
          fmap VarArray (replicateM count element)
-- | Decodes a fixed-size array into 'FixArray' with the 'FromArray'
-- decoder selected by @dims@ and @ty@.
--
-- NOTE(review): @dims@ and @ty@ occur only in the instance context, not in
-- the head, and 'FromArray' has no functional dependencies, so they may
-- be ambiguous at use sites. Confirm that uses of this instance
-- type-check.
instance FromArray dims ty y => FromPG (FixArray y) where
  fromPG = devalue (fmap FixArray (array (fromArray @dims @ty @y)))
-- | Decodes a PostgreSQL enum label into the constructor of @y@ whose name
-- equals the label, using @generics-sop@ datatype information. A label
-- that matches no constructor gives 'Nothing' to 'enum'.
instance
  ( SOP.IsEnumType y
  , SOP.HasDatatypeInfo y
  , LabelsPG y ~ labels
  ) => FromPG (Enumerated y) where
    fromPG = devalue (Enumerated <$> enum fromLabelText)
      where
        fromLabelText :: Strict.Text -> Maybe y
        fromLabelText label =
          SOP.to <$> matchConstructor
            (SOP.constructorInfo (SOP.datatypeInfo (SOP.Proxy @y)))
            (Strict.Text.unpack label)

        -- Walk the constructors in order. The first one whose name
        -- matches is selected by its position in the sum.
        matchConstructor
          :: SOP.All ((~) '[]) xss
          => NP SOP.ConstructorInfo xss
          -> String
          -> Maybe (SOP.SOP SOP.I xss)
        matchConstructor Nil _ = Nothing
        matchConstructor (info :* infos) name
          | name == SOP.constructorName info = Just (SOP.SOP (SOP.Z Nil))
          | otherwise =
              SOP.SOP . SOP.S . SOP.unSOP <$> matchConstructor infos name
-- | Decodes a composite value into the record @y@ by running 'genericRow'
-- over its fields through 'rowValue'.
instance
  ( SOP.IsRecord y ys
  , SOP.AllZip FromField row ys
  , RowPG y ~ row
  ) => FromPG (Composite y) where
    fromPG = rowValue (fmap Composite genericRow)
-- | Decodes a range value. A leading flag byte is read first.
--
-- * Bit 0 set: the range is 'Empty'.
-- * Bit 1 / bit 2: the lower / upper bound is inclusive ('Closed'),
--   otherwise it is exclusive ('Open').
-- * Bit 3 / bit 4: the lower / upper bound is 'Infinite'.
--
-- Each finite bound is stored as a 4-byte length followed by that many
-- bytes, which are decoded with @y@'s 'fromPG'. The lower bound is read
-- before the upper one.
instance FromPG y => FromPG (Range y) where
  fromPG = devalue $ do
    flags <- byte
    let
      bound infinityBit inclusiveBit
        | testBit flags infinityBit = pure Infinite
        | otherwise = do
            size <- sized 4 int
            value <- sized size (revalue fromPG)
            pure $ if testBit flags inclusiveBit then Closed value else Open value
    if testBit flags 0
      then pure Empty
      else NonEmpty <$> bound 3 1 <*> bound 4 2
-- | A 'FromValue' instance decodes a single column into a Haskell value
-- @y@. @ty@ is the column's 'NullType', and the input is the column's raw
-- bytes, or 'Nothing' for @NULL@.
class FromValue (ty :: NullType) (y :: Type) where
  fromValue
    :: Maybe Strict.ByteString
    -> Either Strict.Text y
-- | A @NOT NULL@ column is parsed with @y@'s 'fromPG'. A @NULL@ column is
-- reported as an error.
instance (FromPG y, pg ~ PG y) => FromValue ('NotNull pg) y where
  fromValue = maybe
    (throwError "fromField: saw NULL when expecting NOT NULL")
    (valueParser (revalue fromPG))
-- | A nullable column decodes to 'Maybe': @NULL@ becomes 'Nothing', and
-- any other value is parsed with @y@'s 'fromPG' and wrapped in 'Just'.
instance (FromPG y, pg ~ PG y) => FromValue ('Null pg) (Maybe y) where
  fromValue = traverse (valueParser (revalue fromPG))
-- | A 'FromField' instance decodes a named column into a named Haskell
-- field, wrapping the result in 'SOP.P'. The input is the column's raw
-- bytes, or 'Nothing' for @NULL@.
class FromField (field :: (Symbol, NullType)) (y :: (Symbol, Type)) where
  fromField
    :: Maybe Strict.ByteString
    -> Either Strict.Text (SOP.P y)
-- | A column decodes into a field through the 'FromValue' instance of its
-- 'NullType'. The @fld0 ~ fld1@ constraint requires the column name and
-- the field name to agree.
instance (FromValue ty y, fld0 ~ fld1)
  => FromField (fld0 ::: ty) (fld1 ::: y) where
    fromField bytes = SOP.P <$> fromValue @ty bytes
-- | A 'FromArray' instance gives an array decoder for the Haskell type @y@.
-- @dims@ lists the fixed dimensions still to be decoded as products, and
-- @ty@ is the 'NullType' of the innermost elements.
class FromArray (dims :: [Nat]) (ty :: NullType) (y :: Type) where
  fromArray
    :: Array y
-- | Base case: a @NOT NULL@ element is decoded with @y@'s 'fromPG' via
-- 'valueArray'.
instance (FromPG y, pg ~ PG y) => FromArray '[] ('NotNull pg) y where
  fromArray = valueArray $ revalue fromPG

-- | Base case: a nullable element is decoded to 'Maybe' with @y@'s
-- 'fromPG' via 'nullableValueArray'.
instance (FromPG y, pg ~ PG y) => FromArray '[] ('Null pg) (Maybe y) where
  fromArray = nullableValueArray $ revalue fromPG
-- | Inductive case: one fixed dimension of size @dim@ is decoded into a
-- homogeneous product type (for example a tuple) with @dim@ fields, each
-- decoded by the 'FromArray' instance for the remaining @dims@.
instance
  ( SOP.IsProductType product ys
  , Length ys ~ dim
  , SOP.All ((~) y) ys
  , FromArray dims ty y )
  => FromArray (dim ': dims) ty product where
    fromArray = dimensionArray fill (fromArray @dims @ty @y)
      where
        -- The element count supplied by 'dimensionArray' is ignored. The
        -- number of elements read is fixed by the length of @ys@.
        fill _ element = SOP.to . SOP.SOP . SOP.Z <$> replicateMN element
-- | Run the action @mx@ once for each position of the type-level list
-- @xs@, in order, collecting the results into an n-ary product. Every
-- element type of @xs@ must equal @x@.
replicateMN
  :: forall x xs m. (SOP.All ((~) x) xs, Monad m, SOP.SListI xs)
  => m x -> m (SOP.NP SOP.I xs)
replicateMN mx =
  SOP.hsequence' $ SOP.hcpure (SOP.Proxy @((~) x)) $ SOP.Comp $ fmap SOP.I mx
-- | A 'DecodeRow' decodes a row of type @row@ into a Haskell value @y@.
-- It reads the raw column bytes of the row ('Nothing' for @NULL@) and may
-- fail with a 'Strict.Text' error message.
newtype DecodeRow (row :: RowType) (y :: Type) = DecodeRow
  { unDecodeRow
      :: ReaderT
          (SOP.NP (SOP.K (Maybe Strict.ByteString)) row)
          (Except Strict.Text)
          y
  }
  deriving newtype
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadError Strict.Text
    )
-- | 'fail' reports its message as a decoding error through 'throwError'.
instance MonadFail (DecodeRow row) where
  fail message = throwError (fromString message)
-- | Run a 'DecodeRow' against the raw column values of a row ('Nothing'
-- for @NULL@), returning either the decoded value or an error message.
runDecodeRow
  :: DecodeRow row y
  -> SOP.NP (SOP.K (Maybe Strict.ByteString)) row
  -> Either Strict.Text y
runDecodeRow decoder columns =
  runExcept (runReaderT (unDecodeRow decoder) columns)
-- | Decode a row made of two joined rows. The left decoder runs on the
-- leading columns and the right decoder on the remaining ones. Their
-- results are combined with the given function.
appendRows
  :: SOP.SListI left
  => (l -> r -> z)
  -> DecodeRow left l
  -> DecodeRow right r
  -> DecodeRow (Join left right) z
appendRows combine decodeLeft decodeRight = decodeRow $ \row ->
  case disjoin row of
    (leftColumns, rightColumns) ->
      liftA2 combine
        (runDecodeRow decodeLeft leftColumns)
        (runDecodeRow decodeRight rightColumns)
-- | Extend a 'DecodeRow' with a decoder for a new first column. The first
-- column is decoded with its 'FromValue' instance, the rest with the given
-- decoder, and the two results are combined with the given function. The
-- 'Alias' argument is not inspected; it only fixes the column name @col@.
consRow
  :: FromValue head h
  => (h -> t -> z)
  -> Alias col
  -> DecodeRow tail t
  -> DecodeRow (col ::: head ': tail) z
consRow combine _ decodeRest = decodeRow $ \case
  (SOP.K column :: SOP.K (Maybe Strict.ByteString) (col ::: head)) :* columns ->
    liftA2 combine (fromValue @head column) (runDecodeRow decodeRest columns)
-- | Build a 'DecodeRow' from a function on the raw column values of a
-- row ('Nothing' for @NULL@). A 'Left' result becomes a decoding error.
decodeRow
  :: (SOP.NP (SOP.K (Maybe Strict.ByteString)) row -> Either Strict.Text y)
  -> DecodeRow row y
decodeRow dec = DecodeRow (ReaderT (\columns -> liftEither (dec columns)))
-- | @#fld@ decodes the column named @fld@ with its 'FromValue' instance
-- when it is the first column of the row.
instance {-# OVERLAPPING #-} FromValue ty y
  => IsLabel fld (DecodeRow (fld ::: ty ': row) y) where
    fromLabel = decodeRow $ \case
      SOP.K bytes SOP.:* _ -> fromValue @ty bytes

-- | When the first column is not named @fld@, skip it and look for @fld@
-- among the remaining columns.
instance {-# OVERLAPPABLE #-} IsLabel fld (DecodeRow row y)
  => IsLabel fld (DecodeRow (field ': row) y) where
    fromLabel = decodeRow $ \case
      _ SOP.:* columns -> runDecodeRow (fromLabel @fld) columns
-- | @#fld@ as a 'MaybeT' decoder: the column named @fld@ is decoded to a
-- 'Maybe' with its 'FromValue' instance, and a 'Nothing' result becomes an
-- empty 'MaybeT' result. This instance applies when @fld@ is the first
-- column.
instance {-# OVERLAPPING #-} FromValue ty (Maybe y)
  => IsLabel fld (MaybeT (DecodeRow (fld ::: ty ': row)) y) where
    fromLabel = MaybeT $ decodeRow $ \case
      SOP.K bytes SOP.:* _ -> fromValue @ty bytes

-- | When the first column is not named @fld@, skip it and look for @fld@
-- among the remaining columns.
instance {-# OVERLAPPABLE #-} IsLabel fld (MaybeT (DecodeRow row) y)
  => IsLabel fld (MaybeT (DecodeRow (field ': row)) y) where
    fromLabel = MaybeT $ decodeRow $ \case
      _ SOP.:* columns -> runDecodeRow (runMaybeT (fromLabel @fld)) columns
-- | Decode a row into the record type @y@ generically. Each column is
-- paired with the corresponding record field and decoded with 'FromField'.
-- The first failing field aborts decoding.
genericRow :: forall row y ys.
  ( SOP.IsRecord y ys
  , SOP.AllZip FromField row ys
  ) => DecodeRow row y
genericRow = DecodeRow $ ReaderT $ \columns ->
  SOP.fromRecord <$> SOP.hsequence'
    (SOP.htrans (SOP.Proxy @FromField) (SOP.Comp . runField) columns)
-- | Decode a single raw column with its 'FromField' instance, lifting the
-- 'Either' result into 'Except'.
runField
  :: forall ty y. FromField ty y
  => SOP.K (Maybe Strict.ByteString) ty
  -> Except Strict.Text (SOP.P y)
runField (SOP.K bytes) = liftEither (fromField @ty bytes)
-- | Build a decoder for an enum type from one Haskell value per type-level
-- label in @labels@. A received label decodes to the value paired with the
-- first equal type-level label. An unmatched label gives 'Nothing' to
-- 'enum'.
enumValue
  :: (SOP.All KnownSymbol labels, PG y ~ 'PGenum labels)
  => NP (SOP.K y) labels
  -> StateT Strict.ByteString (Except Strict.Text) y
enumValue = devalue . enum . matchLabel
  where
    matchLabel
      :: SOP.All KnownSymbol labels
      => NP (SOP.K y) labels
      -> Text -> Maybe y
    matchLabel Nil _ = Nothing
    matchLabel ((SOP.K value :: SOP.K y label) :* rest) received
      | received == fromString (symbolVal (SOP.Proxy @label)) = Just value
      | otherwise = matchLabel rest received