{-# LANGUAGE CPP #-}
module Composite.Opaleye.TH where

import Control.Lens ((<&>))
import qualified Data.ByteString.Char8 as BSC8
import Data.List.Split (splitOn)
import Data.Maybe (fromMaybe)
import Data.Profunctor.Product.Default (Default, def)
import Data.Traversable (for)
import Database.PostgreSQL.Simple (ResultError(ConversionFailed, Incompatible, UnexpectedNull))
import Database.PostgreSQL.Simple.FromField (FromField, fromField, typename, returnError)
import Language.Haskell.TH
  ( Q, Name, mkName, nameBase, newName, pprint, reify
  , Info(TyConI), Dec(DataD), Con(NormalC)
  , conT
  , dataD, instanceD
  , lamE, varE, caseE, conE
  , conP, varP, wildP, litP, stringL
  , caseE, match
  , funD, clause
  , normalB, normalGE, guardedB
  , cxt
  )
import Language.Haskell.TH.Syntax (lift)
import Opaleye
  ( DefaultFromField, Field, ToFields, fromPGSFromField, defaultFromField, toToFields
  )
import Opaleye.Internal.PGTypes (IsSqlType, showSqlType, literalColumn)
import Opaleye.Internal.HaskellDB.PrimQuery (Literal(StringLit))

getLastComponent :: String -> String
getLastComponent :: [Char] -> [Char]
getLastComponent [Char]
str = case forall a. [a] -> [a]
reverse (forall a. Eq a => [a] -> [a] -> [[a]]
splitOn [Char]
"." [Char]
str) of
  [Char]
x:[[Char]]
_ -> [Char]
x
  [] -> [Char]
str

-- |Derive the various instances required to make a Haskell enumeration map to a PostgreSQL @enum@ type.
--
-- In @deriveOpaleyeEnum ''HaskellType "schema.sqltype" hsConToSqlValue@, @''HaskellType@ is the sum type (data declaration) to make instances for, 
-- @"schema.sqltype"@ is the PostgreSQL type name, and @hsConToSqlValue@ is a function to map names of constructors to SQL values.
--
-- The function @hsConToSqlValue@ is of the type @String -> Maybe String@ in order to make using 'stripPrefix' convenient. The function is applied to each
-- constructor name and for @Just value@ that value is used, otherwise for @Nothing@ the constructor name is used.
--
-- For example, given the Haskell type:
--
-- @
--     data MyEnum = MyFoo | MyBar
-- @
--
-- And PostgreSQL type:
--
-- @
--     CREATE TYPE myenum AS ENUM('foo', 'bar');
-- @
--
-- The splice:
--
-- @
--     deriveOpaleyeEnum ''MyEnum "myschema.myenum" ('stripPrefix' "my" . 'map' 'toLower')
-- @
--
-- Will create @PGMyEnum@ and instances required to use @MyEnum@ / @Field MyEnum@ in Opaleye.
--
-- The Haskell generated by this splice for the example is something like:
--
-- @
--     data PGMyEnum
--
--     instance 'IsSqlType' PGMyEnum where
--       'showSqlType' _ = "myschema.myenum"
--
--     instance 'FromField' MyEnum where
--       'fromField' f mbs = do
--         tname <- 'typename' f
--         case mbs of
--           _ | 'getLastComponent' ('BSC8.unpack' tname) /= "myenum" -> 'returnError' 'Incompatible' f ""
--           Just "foo" -> pure MyFoo
--           Just "bar" -> pure MyBar
--           Just other -> 'returnError' 'ConversionFailed' f ("Unexpected myschema.myenum value: " <> 'BSC8.unpack' other)
--           Nothing    -> 'returnError' 'UnexpectedNull' f ""
--
--     instance 'DefaultFromField' PGMyEnum MyEnum where
--       defaultFromField = 'fromPGSFromField'
--
--     instance 'Default' 'ToFields' MyEnum ('Field' PGMyEnum) where
--       def = 'toToFields' $ \ a ->
--         'literalColumn' . 'stringLit' $ case a of
--           MyFoo -> "foo"
--           MyBar -> "bar"
-- @
deriveOpaleyeEnum :: Name -> String -> (String -> Maybe String) -> Q [Dec]
deriveOpaleyeEnum :: Name -> [Char] -> ([Char] -> Maybe [Char]) -> Q [Dec]
deriveOpaleyeEnum Name
hsName [Char]
sqlName [Char] -> Maybe [Char]
hsConToSqlValue = do
  let sqlTypeName :: Name
sqlTypeName = [Char] -> Name
mkName forall a b. (a -> b) -> a -> b
$ [Char]
"PG" forall a. [a] -> [a] -> [a]
++ Name -> [Char]
nameBase Name
hsName
      sqlType :: Q Type
sqlType = forall (m :: * -> *). Quote m => Name -> m Type
conT Name
sqlTypeName
      hsType :: Q Type
hsType = forall (m :: * -> *). Quote m => Name -> m Type
conT Name
hsName
      unqualSqlName :: [Char]
unqualSqlName = [Char] -> [Char]
getLastComponent [Char]
sqlName

  [Con]
rawCons <- Name -> Q Info
reify Name
hsName forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \ case
    TyConI (DataD Cxt
_cxt Name
_name [TyVarBndr ()]
_tvVarBndrs Maybe Type
_maybeKind [Con]
cons [DerivClause]
_derivingCxt) ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure [Con]
cons
    Info
other ->
      forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail forall a b. (a -> b) -> a -> b
$ [Char]
"expected " forall a. Semigroup a => a -> a -> a
<> forall a. Show a => a -> [Char]
show Name
hsName forall a. Semigroup a => a -> a -> a
<> [Char]
" to name a data declaration, not:\n" forall a. Semigroup a => a -> a -> a
<> forall a. Ppr a => a -> [Char]
pprint Info
other

  [Name]
nullaryCons <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
t a -> (a -> f b) -> f (t b)
for [Con]
rawCons forall a b. (a -> b) -> a -> b
$ \ case
    NormalC Name
conName [] ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure Name
conName
    Con
other ->
      forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail forall a b. (a -> b) -> a -> b
$ [Char]
"expected every constructor of " forall a. Semigroup a => a -> a -> a
<> forall a. Show a => a -> [Char]
show Name
hsName forall a. Semigroup a => a -> a -> a
<> [Char]
" to be a regular nullary constructor, not:\n" forall a. Semigroup a => a -> a -> a
<> forall a. Ppr a => a -> [Char]
pprint Con
other

  let conPairs :: [(Name, [Char])]
conPairs = [Name]
nullaryCons forall (f :: * -> *) a b. Functor f => f a -> (a -> b) -> f b
<&> \ Name
conName ->
        (Name
conName, forall a. a -> Maybe a -> a
fromMaybe (Name -> [Char]
nameBase Name
conName) ([Char] -> Maybe [Char]
hsConToSqlValue (Name -> [Char]
nameBase Name
conName)))

  Dec
sqlTypeDecl <-
    forall (m :: * -> *).
Quote m =>
m Cxt
-> Name
-> [TyVarBndr ()]
-> Maybe Type
-> [m Con]
-> [m DerivClause]
-> m Dec
dataD
      (forall (m :: * -> *). Quote m => [m Type] -> m Cxt
cxt [])
      Name
sqlTypeName
      []
      forall a. Maybe a
Nothing
      []
#if MIN_VERSION_template_haskell(2,12,0)
      []
#else
      (cxt [])
#endif

  Dec
isSqlTypeInst <- forall (m :: * -> *).
Quote m =>
m Cxt -> m Type -> [m Dec] -> m Dec
instanceD (forall (m :: * -> *). Quote m => [m Type] -> m Cxt
cxt []) [t| IsSqlType $sqlType |] forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:[]) forall a b. (a -> b) -> a -> b
$ do
    forall (m :: * -> *). Quote m => Name -> [m Clause] -> m Dec
funD 'showSqlType
      [ forall (m :: * -> *).
Quote m =>
[m Pat] -> m Body -> [m Dec] -> m Clause
clause
          [forall (m :: * -> *). Quote m => m Pat
wildP]
          (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB (forall t (m :: * -> *). (Lift t, Quote m) => t -> m Exp
lift [Char]
sqlName))
          []
      ]

  Dec
fromFieldInst <- forall (m :: * -> *).
Quote m =>
m Cxt -> m Type -> [m Dec] -> m Dec
instanceD (forall (m :: * -> *). Quote m => [m Type] -> m Cxt
cxt []) [t| FromField $hsType |] forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:[]) forall a b. (a -> b) -> a -> b
$ do
    Name
field <- forall (m :: * -> *). Quote m => [Char] -> m Name
newName [Char]
"field"
    Name
mbs   <- forall (m :: * -> *). Quote m => [Char] -> m Name
newName [Char]
"mbs"
    Name
tname <- forall (m :: * -> *). Quote m => [Char] -> m Name
newName [Char]
"tname"
    Name
other <- forall (m :: * -> *). Quote m => [Char] -> m Name
newName [Char]
"other"

    let bodyCase :: Q Exp
bodyCase = forall (m :: * -> *). Quote m => m Exp -> [m Match] -> m Exp
caseE (forall (m :: * -> *). Quote m => Name -> m Exp
varE Name
mbs) forall a b. (a -> b) -> a -> b
$
          [ forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match
              forall (m :: * -> *). Quote m => m Pat
wildP
              (forall (m :: * -> *). Quote m => [m (Guard, Exp)] -> m Body
guardedB [ forall (m :: * -> *). Quote m => m Exp -> m Exp -> m (Guard, Exp)
normalGE [| getLastComponent (BSC8.unpack $(varE tname)) /= $(lift unqualSqlName) |]
                                   [| returnError Incompatible $(varE field) "" |] ])
              []
          ] forall a. [a] -> [a] -> [a]
++
          (
            [(Name, [Char])]
conPairs forall (f :: * -> *) a b. Functor f => f a -> (a -> b) -> f b
<&> \ (Name
conName, [Char]
value) ->
              forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match
                [p| Just $(litP $ stringL value) |]
                (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [| pure $(conE conName) |])
                []
          ) forall a. [a] -> [a] -> [a]
++
          [ forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match 
              [p| Just $(varP other) |]
              (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [| returnError ConversionFailed $(varE field) ("Unexpected " <> $(lift sqlName) <> " value: " <> BSC8.unpack $(varE other)) |])
              []
          , forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match
              [p| Nothing |]
              (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [| returnError UnexpectedNull $(varE field) "" |])
              []
          ]

    forall (m :: * -> *). Quote m => Name -> [m Clause] -> m Dec
funD 'fromField
      [ forall (m :: * -> *).
Quote m =>
[m Pat] -> m Body -> [m Dec] -> m Clause
clause
          [forall (m :: * -> *). Quote m => Name -> m Pat
varP Name
field, forall (m :: * -> *). Quote m => Name -> m Pat
varP Name
mbs]
          (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [|
            do
              $(varP tname) <- typename $(varE field)
              $bodyCase
            |])
          []
      ]

  Dec
defaultFromFieldInst <- forall (m :: * -> *).
Quote m =>
m Cxt -> m Type -> [m Dec] -> m Dec
instanceD (forall (m :: * -> *). Quote m => [m Type] -> m Cxt
cxt []) [t| DefaultFromField $sqlType $hsType |] forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:[]) forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *). Quote m => Name -> [m Clause] -> m Dec
funD 'defaultFromField
      [ forall (m :: * -> *).
Quote m =>
[m Pat] -> m Body -> [m Dec] -> m Clause
clause
          []
          (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [| fromPGSFromField |])
          []
      ]

  Dec
defaultInst <- forall (m :: * -> *).
Quote m =>
m Cxt -> m Type -> [m Dec] -> m Dec
instanceD (forall (m :: * -> *). Quote m => [m Type] -> m Cxt
cxt []) [t| Default ToFields $hsType (Field $sqlType) |] forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:[]) forall a b. (a -> b) -> a -> b
$ do
    Name
s <- forall (m :: * -> *). Quote m => [Char] -> m Name
newName [Char]
"s"
    let body :: Q Exp
body = forall (m :: * -> *). Quote m => [m Pat] -> m Exp -> m Exp
lamE [forall (m :: * -> *). Quote m => Name -> m Pat
varP Name
s] forall a b. (a -> b) -> a -> b
$
          forall (m :: * -> *). Quote m => m Exp -> [m Match] -> m Exp
caseE (forall (m :: * -> *). Quote m => Name -> m Exp
varE Name
s) forall a b. (a -> b) -> a -> b
$
            [(Name, [Char])]
conPairs forall (f :: * -> *) a b. Functor f => f a -> (a -> b) -> f b
<&> \ (Name
conName, [Char]
value) ->
              forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match
                (forall (m :: * -> *). Quote m => Name -> [m Pat] -> m Pat
conP Name
conName [])
                (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB forall a b. (a -> b) -> a -> b
$ forall t (m :: * -> *). (Lift t, Quote m) => t -> m Exp
lift [Char]
value)
                []

    forall (m :: * -> *). Quote m => Name -> [m Clause] -> m Dec
funD 'def
      [ forall (m :: * -> *).
Quote m =>
[m Pat] -> m Body -> [m Dec] -> m Clause
clause
          []
          (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB [| toToFields (literalColumn . StringLit . $body) |])
          []
      ]

  forall (f :: * -> *) a. Applicative f => a -> f a
pure [Dec
sqlTypeDecl, Dec
isSqlTypeInst, Dec
fromFieldInst, Dec
defaultFromFieldInst, Dec
defaultInst]