{-# language GADTs #-}
{-# language NamedFieldPuns #-}
{-# language ScopedTypeVariables #-}
{-# language StandaloneKindSignatures #-}
{-# language TypeApplications #-}

module Rel8.Statement.Update
  ( Update(..)
  , update
  )
where

-- base
import Control.Exception ( throwIO )
import Data.Kind ( Type )
import Prelude

-- hasql
import Hasql.Connection ( Connection )
import qualified Hasql.Decoders as Hasql
import qualified Hasql.Encoders as Hasql
import qualified Hasql.Session as Hasql
import qualified Hasql.Statement as Hasql

-- opaleye
import qualified Opaleye.Internal.Manipulation as Opaleye

-- profunctors
import Data.Profunctor ( lmap )

-- rel8
import Rel8.Expr ( Expr )
import Rel8.Expr.Opaleye ( toColumn, toPrimExpr )
import Rel8.Schema.Name ( Selects )
import Rel8.Schema.Table ( TableSchema )
import Rel8.Statement.Returning ( Returning( Projection, NumberOfRowsAffected ) )
import Rel8.Table ( fromColumns, toColumns )
import Rel8.Table.Insert ( toInsert )
import Rel8.Table.Opaleye ( castTable, table, unpackspec )
import Rel8.Table.Serialize ( Serializable, parse )

-- text
import qualified Data.Text as Text
import Data.Text.Encoding ( encodeUtf8 )


-- | The constituent parts of an @UPDATE@ statement.
--
-- The @names@ (column names of the target table) and @exprs@ (the row as
-- expressions) types are hidden by the constructor; they are tied together
-- only by the 'Selects' constraint. Only the result type @a@, fixed by
-- 'returning', is visible from outside.
type Update :: Type -> Type
data Update a where
  Update :: Selects names exprs =>
    { target :: TableSchema names
      -- ^ Which table to update.
    , set :: exprs -> exprs
      -- ^ How to update each selected row: maps the current row to the new
      -- row.
    , updateWhere :: exprs -> Expr Bool
      -- ^ Which rows to select for update.
    , returning :: Returning names a
      -- ^ What to return from the @UPDATE@ statement.
    }
    -> Update a


-- | Run an @UPDATE@ statement.
--
-- Opaleye renders the statement to SQL text. It is then executed on the
-- given 'Connection' as a single unprepared Hasql statement with no
-- parameters. If the session reports an error, that error is rethrown in
-- 'IO' with 'throwIO'.
--
-- * 'NumberOfRowsAffected' runs a plain @UPDATE@ and returns the affected row
--   count.
-- * 'Projection' runs @UPDATE ... RETURNING@ and decodes every returned row
--   with the projection's 'Serializable' decoder.
update :: Connection -> Update a -> IO a
update c Update {target, set, updateWhere, returning} =
  case returning of
    NumberOfRowsAffected ->
      run Hasql.rowsAffected $
        Opaleye.arrangeUpdateSql target' set' where'

    Projection project ->
      run (decoder project) $
        Opaleye.arrangeUpdateReturningSql
          unpackspec
          target'
          set'
          where'
          (castTable . toColumns . project . fromColumns)

  where
    -- Execute rendered SQL on @c@, decoding the result with @decode@.
    -- The signature is required, not decorative: this binding closes over
    -- @c@, so under MonoLocalBinds (implied by GADTs) it would not be
    -- generalised, yet it is used at two different result types.
    run :: Hasql.Result r -> String -> IO r
    run decode sql =
      Hasql.run (Hasql.statement () statement) c >>= either throwIO pure
      where
        statement =
          Hasql.Statement
            (encodeUtf8 (Text.pack sql))
            Hasql.noParams
            decode
            prepare
        -- The SQL text is built per call, so the statement is not prepared.
        prepare = False

    -- The target table as Opaleye columns. 'lmap' 'toInsert' adapts the
    -- table's write side so that it accepts the same expression columns it
    -- reads, which is what the update function produces.
    target' = lmap toInsert $ table $ toColumns <$> target

    -- The user's update function, lifted to work on columns.
    set' = toColumns . set . fromColumns

    -- The user's row filter, lifted to work on columns.
    where' = toColumn . toPrimExpr . updateWhere . fromColumns

    -- A row-list decoder for the projection's result type. The function
    -- argument only fixes the @projection@ type; its value is never used.
    decoder :: forall exprs projection a. Serializable projection a
      => (exprs -> projection) -> Hasql.Result [a]
    decoder _ = Hasql.rowList (parse @projection @a)