module Data.Label.Derive
( mkLabels
, mkLabelsNoTypes
) where
import Control.Arrow
import Control.Category
import Control.Monad
import Data.Char
import Data.Function (on)
import Data.Label.Abstract
import Data.List
import Data.Ord
import Data.String
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Prelude hiding ((.), id)
-- | Abort with 'error', prefixing the message with this module's name so
-- that failures raised while deriving labels are easy to attribute.
fclError :: String -> a
fclError = error . ("Data.Label.Derive: " ++)
-- | Derive labels for every record field of each of the given datatypes,
-- emitting an explicit type signature next to every label definition.
mkLabels :: [Name] -> Q [Dec]
mkLabels names = fmap concat (mapM (derive1 True) names)
-- | Like 'mkLabels', but emits only the label definitions and no type
-- signatures.
mkLabelsNoTypes :: [Name] -> Q [Dec]
mkLabelsNoTypes names = fmap concat (mapM (derive1 False) names)
-- | Derive the labels for a single datatype or newtype.
--
-- The 'Bool' selects whether type signatures are emitted alongside the
-- label definitions. The type is reified, and the fields of all record
-- constructors are gathered and grouped by field name. A field that
-- occurs in several constructors therefore yields one label that knows
-- every constructor containing it. Non-record constructors contribute no
-- fields but still count towards the constructor total handed to 'derive'.
derive1 :: Bool -> Name -> Q [Dec]
derive1 signatures datatype = do
  info <- reify datatype
  let (tyname, cons, vars) = unpackInfo info
      fieldsWithCtors      = byField [ (fld, ctor) | RecC ctor flds <- cons, fld <- flds ]
  labels <- mapM (derive signatures tyname vars (length cons)) fieldsWithCtors
  return (concat labels)
  where
    -- Extract the type name, constructors and type variable binders from
    -- the reified declaration; anything other than data/newtype is rejected.
    unpackInfo (TyConI (DataD    _ n vs cs _)) = (n, cs, vs)
    unpackInfo (TyConI (NewtypeD _ n vs c  _)) = (n, [c], vs)
    unpackInfo _ = fclError "Can only derive labels for datatypes and newtypes."

    -- Sort by field name, collect equal names together and pair each field
    -- with the constructors it occurs in. 'groupBy' never yields an empty
    -- group, so 'head' is safe here.
    byField pairs =
      [ (fst (head grp), map snd grp)
      | grp <- groupBy ((==) `on` nameOf) (sortBy (comparing nameOf) pairs)
      ]

    nameOf ((n, _, _), _) = n
-- | Generate the declarations for a single label.
--
-- Arguments, in order:
--
-- * whether to emit a type signature;
-- * the name of the datatype;
-- * its type variable binders;
-- * the total number of constructors of the datatype;
-- * the field (name, strictness, type) paired with the names of the record
--   constructors that contain it.
--
-- If every constructor contains the field, the label only requires an
-- 'Arrow' instance. Otherwise it requires 'ArrowChoice' and 'ArrowZero'.
-- In that case its getter and setter match on the constructors that have
-- the field and fall back to 'zeroArrow' for all others.
--
-- The label is named after the field:
--
-- * a leading underscore is dropped and the next character lower-cased
--   (@_foo@ becomes @foo@);
-- * any other name is prefixed with @l@ and its first character
--   upper-cased (@foo@ becomes @lFoo@).
derive :: Bool -> Name -> [TyVarBndr] -> Int -> (VarStrictType, [Name]) -> Q [Dec]
derive signatures tyname vars total ((field, _, fieldtyp), ctors) =
  do (sign, body) <-
       if length ctors == total
         then function derivePureLabel
         else function deriveMaybeLabel
     return $
       if signatures
         then [sign, body]
         else [body]
  where
    -- Label for a field missing from some constructors. Getter and setter
    -- produce 'Left ()' for constructors lacking the field; 'c' maps that
    -- to 'zeroArrow' and passes 'Right' results through with 'returnA'.
    deriveMaybeLabel = (sign, body)
      where
        sign = forallT vars (return []) [t| (ArrowChoice (~>), ArrowZero (~>)) => Lens (~>) $(inputType) $(return fieldtyp) |]
        body = [| let c = zeroArrow ||| returnA in lens (c . $(getter)) (c . $(setter)) |]
          where
            getter = [| arr (\ p -> $(caseE [|p|] (cases (bodyG [|p|] ) ++ wild))) |]
            setter = [| arr (\(v, p) -> $(caseE [|p|] (cases (bodyS [|p|] [|v|]) ++ wild))) |]
            -- One alternative per constructor that has the field, matched
            -- with an empty record pattern @C {}@.
            cases b = map (\ctor -> match (recP ctor []) (normalB b) []) ctors
            wild = [match wildP (normalB [| Left () |]) []]
            bodyS p v = [| Right $( record p fieldName v ) |]
            bodyG p = [| Right $( fromString fieldName `appE` p ) |]

    -- Label for a field present in every constructor: the record selector
    -- is the getter and a record update is the setter.
    derivePureLabel = (sign, body)
      where
        sign = forallT vars (return []) [t| Arrow (~>) => Lens (~>) $(inputType) $(return fieldtyp) |]
        body = [| lens $(getter) $(setter) |]
          where
            getter = [| arr $(fromString fieldName) |]
            setter = [| arr (\(v, p) -> $(record [| p |] fieldName [| v |])) |]

    fieldName = nameBase field

    labelName = mkName $
      case nameBase field of
        '_' : c : rest -> toLower c : rest
        f : rest -> 'l' : toUpper f : rest
        n -> fclError ("Cannot derive label for record selector with name: " ++ n)

    -- The datatype applied to its own type variables, e.g. @T a b@.
    inputType = return $ foldl' AppT (ConT tyname) (map tvToVarT vars)

    -- Referring to a type variable needs only its name. Any kind
    -- annotation is already present on the binder that 'forallT' places
    -- in the generated signature, so kinded binders are accepted too.
    tvToVarT (PlainTV tv) = VarT tv
    tvToVarT (KindedTV tv _) = VarT tv

    -- Record update expression @rec { fld = val }@.
    record rec fld val = val >>= \v -> recUpdE rec [return (mkName fld, v)]

    -- Produce the signature and the argument-less definition of the label.
    function (s, b) = liftM2 (,)
      (sigD labelName s)
      (funD labelName [ clause [] (normalB b) [] ])
-- | Let a string literal stand for a variable expression. The string is
-- turned into an unqualified 'Name' with 'mkName'. This is an orphan
-- instance: neither 'IsString' nor 'Exp' is defined in this module.
instance IsString Exp where
  fromString name = VarE (mkName name)
-- | Let a string literal stand for a variable pattern, built with 'varP'
-- from an unqualified 'Name'. This is an orphan instance.
instance IsString (Q Pat) where
  fromString name = varP (mkName name)
-- | Let a string literal stand for a variable expression inside 'Q', built
-- with 'varE' from an unqualified 'Name'. Used when splicing a field
-- selector by its name. This is an orphan instance.
instance IsString (Q Exp) where
  fromString name = varE (mkName name)