{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

module DeriveHasField (
  module GHC.Records,
  deriveHasFieldWith,
)
where

import Control.Monad
import Data.Char (toLower)
import Data.Foldable as Foldable
import Data.Traversable (for)
import GHC.Records
import Language.Haskell.TH
import Language.Haskell.TH.Datatype

-- | Derive 'HasField' instances for every record field of the named
-- datatype, exposing each field under a new name.
--
-- The first argument maps an existing field name to the wanted one. The
-- datatype is looked up with 'reifyDatatype' and the resulting
-- 'DatatypeInfo' is handed to 'makeDeriveHasField', which performs all
-- validation and builds the declarations.
deriveHasFieldWith :: (String -> String) -> Name -> DecsQ
deriveHasFieldWith fieldModifier =
  makeDeriveHasField fieldModifier <=< reifyDatatype

-- | Build one 'HasField' instance per record field of the given datatype.
--
-- Each instance exposes the field under the name obtained by applying
-- @fieldModifier@ to the original field name and lower-casing the first
-- character of the result; its @getField@ delegates to @getField@ applied
-- at the original field name.
--
-- Fails in 'Q' when:
--
-- * the datatype does not have exactly one constructor;
-- * the declaration is neither @data@ nor @newtype@;
-- * the constructor is not a record constructor;
-- * some field type is neither a type constructor nor a type application;
-- * a renamed field ends up identical to the original field name.
makeDeriveHasField :: (String -> String) -> DatatypeInfo -> DecsQ
makeDeriveHasField fieldModifier datatypeInfo = do
  -- We do not support sum of product types
  constructorInfo <- case datatypeInfo.datatypeCons of
    [info] -> pure info
    _ -> fail "deriveHasField: only supports product types with a single data constructor"

  -- We only support data and newtype declarations
  when (datatypeInfo.datatypeVariant `Foldable.notElem` [Datatype, Newtype]) $
    fail "deriveHasField: only supports data and newtype"

  -- We only support data types with field names and concrete types
  recordConstructorNames <- case constructorInfo.constructorVariant of
    RecordConstructor names -> pure names
    _ -> fail "deriveHasField: only supports constructors with field names"
  unless (Foldable.all isConcreteType constructorInfo.constructorFields) $
    fail "deriveHasField: only supports concrete field types"

  -- Build the instances
  let constructorNamesAndTypes :: [(Name, Type)]
      constructorNamesAndTypes =
        zip recordConstructorNames constructorInfo.constructorFields
      -- The datatype's constructor applied to each of its type variables,
      -- in declaration order (e.g. @T a b@).
      parentType :: Q Type
      parentType =
        foldl'
          (\acc var -> appT acc (varT $ tyVarBndrToName var))
          (conT datatypeInfo.datatypeName)
          datatypeInfo.datatypeVars
  decs <- for constructorNamesAndTypes $ \(name, ty) ->
    let currentFieldName :: String
        currentFieldName = nameBase name
        wantedFieldName :: String
        wantedFieldName = lowerFirst $ fieldModifier currentFieldName
        litTCurrentField :: Q Type
        litTCurrentField = litT $ strTyLit currentFieldName
        litTFieldWanted :: Q Type
        litTFieldWanted = litT $ strTyLit wantedFieldName
     in if currentFieldName == wantedFieldName
          then fail "deriveHasField: after applying fieldModifier, field didn't change"
          else
            -- NOTE(review): @mkName "getField"@ is a dynamically bound name,
            -- so it is looked up where the splice is expanded; presumably
            -- this is why the module re-exports "GHC.Records" -- confirm.
            [d|
              instance HasField $litTFieldWanted $parentType $(pure ty) where
                getField = $(appTypeE (varE $ mkName "getField") litTCurrentField)
              |]
  pure $ Foldable.concat decs
  where
    -- A field type is accepted when it is a bare type constructor or any
    -- type application; everything else is rejected.
    isConcreteType :: Type -> Bool
    isConcreteType (ConT _) = True
    isConcreteType (AppT _ _) = True
    isConcreteType _ = False

-- | Lower-case the first character of a string, leaving the rest untouched.
-- The empty string is returned unchanged.
lowerFirst :: String -> String
lowerFirst [] = []
lowerFirst (x : xs) = toLower x : xs

-- | Extract the variable name from a type-variable binder, ignoring its
-- flag and (for kinded binders) its kind annotation.
tyVarBndrToName :: TyVarBndr flag -> Name
tyVarBndrToName (PlainTV name _) = name
tyVarBndrToName (KindedTV name _ _) = name