{-|
Module: Squeal.PostgreSQL.Manipulation.Update
Description: update statements
Copyright: (c) Eitan Chatav, 2019
Maintainer: eitan@morphism.tech
Stability: experimental

Functions for building PostgreSQL @UPDATE@ statements.
-}

{-# LANGUAGE
    DeriveGeneric
  , DerivingStrategies
  , FlexibleContexts
  , FlexibleInstances
  , GADTs
  , GeneralizedNewtypeDeriving
  , LambdaCase
  , MultiParamTypeClasses
  , OverloadedStrings
  , PatternSynonyms
  , QuantifiedConstraints
  , RankNTypes
  , ScopedTypeVariables
  , TypeApplications
  , TypeFamilies
  , DataKinds
  , PolyKinds
  , TypeOperators
  , UndecidableInstances
#-}

module Squeal.PostgreSQL.Manipulation.Update
  ( -- * Update
    update
  , update_
  ) where

import Data.ByteString hiding (foldr)
import GHC.TypeLits

import qualified Generics.SOP as SOP

import Squeal.PostgreSQL.Type.Alias
import Squeal.PostgreSQL.Expression
import Squeal.PostgreSQL.Expression.Default
import Squeal.PostgreSQL.Expression.Logic
import Squeal.PostgreSQL.Manipulation
import Squeal.PostgreSQL.Type.List
import Squeal.PostgreSQL.Render
import Squeal.PostgreSQL.Type.Schema

-- $setup
-- >>> import Squeal.PostgreSQL

-- | Render one assignment of an @UPDATE ... SET@ list as
-- @column = value@, where the value is an `Optional` expression
-- (for instance one built with `Set`).
renderUpdate
  :: (forall x. RenderSQL (expr x))
  => Aliased (Optional expr) ty
  -> ByteString
renderUpdate (expr `As` col) = renderSQL col <+> "=" <+> renderSQL expr

{-----------------------------------------
UPDATE statements
-----------------------------------------}

{- | An `update` command changes the values of the specified columns
in all rows that satisfy the condition.

>>> type Columns = '["col1" ::: 'Def :=> 'NotNull 'PGint4, "col2" ::: 'NoDef :=> 'NotNull 'PGint4]
>>> type Schema = '["tab1" ::: 'Table ('[] :=> Columns), "tab2" ::: 'Table ('[] :=> Columns)]
>>> :{
let
  manp :: Manipulation with (Public Schema) '[]
    '["col1" ::: 'NotNull 'PGint4,
      "col2" ::: 'NotNull 'PGint4]
  manp = update
    (#tab1 `as` #t1)
    (Set (2 + #t2 ! #col2) `as` #col1)
    (Using (table (#tab2 `as` #t2)))
    (#t1 ! #col1 ./= #t2 ! #col2)
    (Returning (#t1 & DotStar))
in printSQL manp
:}
UPDATE "tab1" AS "t1" SET "col1" = ((2 :: int4) + "t2"."col2") FROM "tab2" AS "t2" WHERE ("t1"."col1" <> "t2"."col2") RETURNING "t1".*
-}
update
  :: ( Has sch db schema
     , Has tab0 schema ('Table table)
     , Updatable table updates
     , SOP.SListI row )
  => Aliased (QualifiedAlias sch) (tab ::: tab0)
  -- ^ table to update, together with the alias used to reference it
  -> NP (Aliased (Optional (Expression 'Ungrouped '[] with db params (tab ::: TableToRow table ': from)))) updates
  -- ^ @SET@ assignments: the new values that replace the old column values
  -> UsingClause with db params from
  -- ^ optional @FROM@ clause: a table expression allowing columns from
  -- other tables to appear in the @WHERE@ condition and the update
  -- expressions; `NoUsing` omits the clause entirely
  -> Condition  'Ungrouped '[] with db params (tab ::: TableToRow table ': from)
  -- ^ @WHERE@ condition selecting the rows to update
  -> ReturningClause with db params (tab ::: TableToRow table ': from) row
  -- ^ results to return
  -> Manipulation with db params row
update (tab0 `As` tab) columns using wh returning = UnsafeManipulation $
  "UPDATE"
  <+> renderSQL tab0 <+> "AS" <+> renderSQL tab
  <+> "SET"
  <+> renderCommaSeparated renderUpdate columns
  <> renderFrom
  <+> "WHERE" <+> renderSQL wh
  <> renderSQL returning
  where
    -- Joined with @<>@ rather than @<+>@ so that an absent FROM clause
    -- contributes nothing at all; the following @<+>@ supplies the single
    -- space before WHERE in either case.
    renderFrom :: ByteString
    renderFrom = case using of
      NoUsing -> ""
      Using tables -> " FROM" <+> renderSQL tables

{- | Update every row that satisfies the condition, with no @FROM@ clause
and returning no columns (@RETURNING@ `Nil`).

>>> type Columns = '["col1" ::: 'Def :=> 'NotNull 'PGint4, "col2" ::: 'NoDef :=> 'NotNull 'PGint4]
>>> type Schema = '["tab" ::: 'Table ('[] :=> Columns)]
>>> :{
let
  manp :: Manipulation with (Public Schema) '[] '[]
  manp = update_ #tab (Set 2 `as` #col1) (#col1 ./= #col2)
in printSQL manp
:}
UPDATE "tab" AS "tab" SET "col1" = (2 :: int4) WHERE ("col1" <> "col2")
-}
update_
  :: ( Has sch db schema
     , Has tab0 schema ('Table table)
     , KnownSymbol tab
     , Updatable table updates )
  => Aliased (QualifiedAlias sch) (tab ::: tab0)
  -- ^ table to update, together with the alias used to reference it
  -> NP (Aliased (Optional (Expression 'Ungrouped '[] with db params '[tab ::: TableToRow table]))) updates
  -- ^ @SET@ assignments: the new values that replace the old column values
  -> Condition  'Ungrouped '[] with db params '[tab ::: TableToRow table]
  -- ^ @WHERE@ condition selecting the rows to update
  -> Manipulation with db params '[]
-- Specialises 'update' with no FROM clause and an empty RETURNING list,
-- so the only table in scope for expressions is the updated one.
update_ tab columns wh = update tab columns NoUsing wh (Returning_ Nil)