{-# LANGUAGE LambdaCase #-}

-- | Template Haskell generator for final (Church-encoded) versions of plain
-- algebraic data types.
--
-- Given, for example,
--
-- > data Stream a = Cons a (Stream a)
--
-- the splice @makeFinalIso "StreamR" ''Stream@ generates
--
-- > newtype StreamR a = StreamR { foldStreamR :: forall r. (a -> r -> r) -> r }
-- > toStreamR   :: Stream a -> StreamR a
-- > fromStreamR :: StreamR a -> Stream a
--
-- and additionally @isoStreamR :: Iso' (Stream a) (StreamR a)@ when both
-- @Control.Lens.iso@ and @Control.Lens.Iso'@ resolve at the splice site.
-- The generated field type is rank-2, so the splice site needs @RankNTypes@.
module Control.Final where

import Control.Monad
import Data.Char (isAlpha, toLower)
import Data.Foldable (foldr')
import Data.List (foldl')
import Language.Haskell.TH

-- | @substituteType typ repl t@ replaces every occurrence of @typ@ inside @t@
-- with @repl@.
--
-- The traversal descends through 'ForallT' bodies, 'AppT' and 'SigT'. A
-- 'ForallT' context, kinds, and any other type form are left untouched.
substituteType :: Type -> Type -> Type -> Type
substituteType typ repl = go
  where
    go s
      | s == typ = repl
      | otherwise = case s of
          ForallT bs c t -> ForallT bs c (go t)
          AppT t1 t2 -> AppT (go t1) (go t2)
          SigT t k -> SigT (go t) k
          _ -> s

-- | @funType r [a, b, c]@ builds the function type @a -> b -> c -> r@;
-- with no argument types it is just @r@.
funType :: Type -> [Type] -> Type
funType = foldr' (AppT . AppT ArrowT)

-- | @funCall f [a, b, c]@ builds the application @f a b c@;
-- with no arguments it is just @f@.
funCall :: Exp -> [Exp] -> Exp
funCall = foldl' AppE

-- | @makeFinalIso dest tyName@ reifies the @data@ or @newtype@ declaration
-- named @tyName@ and generates its final encoding named @dest@ (see
-- 'makeFinalType' for the generated declarations).
--
-- Fails the splice with a compile error for anything that is not a plain
-- @data@ or @newtype@ declaration.
makeFinalIso :: String -> Name -> DecsQ
makeFinalIso dest = reify >=> \case
    TyConI (DataD ctx name binders _ ctors _) ->
        makeFinalType dest ctx name (map simplifyBinder binders) ctors
    TyConI (NewtypeD ctx name binders _ ctor _) ->
        makeFinalType dest ctx name (map simplifyBinder binders) [ctor]
    -- 'fail' in Q reports an ordinary compile error at the splice site,
    -- whereas 'error' escapes as an exception from compile-time code.
    _ -> fail "makeFinalIso only accepts plain ADTs (data or newtype)"
  where
    -- Binders of kind * become plain binders; all others are kept as-is.
    simplifyBinder (KindedTV b StarT) = PlainTV b
    simplifyBinder b = b

-- | Generate the final encoding of a data type from its parts: the name
-- @dest@ for the generated newtype, the type's context, its name, its
-- type-variable binders and its constructors. Produces:
--
-- * @newtype dest vs = dest { fold<dest> :: forall r. … -> r }@, with one
--   argument per constructor (that constructor's field types, then @r@) and
--   every occurrence of the fully applied input type replaced by @r@;
-- * @to<dest>@ and @from<dest>@, converting between the two representations;
-- * @iso<dest>@, built with @iso@ from those conversions, only when both
--   @Control.Lens.iso@ and @Control.Lens.Iso'@ resolve at the splice site.
--
-- Only fields whose type is exactly the fully applied input type are folded
-- recursively by @to<dest>@. Recursion nested inside another type (such as
-- @[T a]@) is substituted in the fold type but not converted, so the
-- generated @to<dest>@ will not type-check for such declarations. GADT
-- constructors are rejected with a compile error.
makeFinalType :: String -> Cxt -> Name -> [TyVarBndr] -> [Con] -> DecsQ
makeFinalType dest ctx name binders ctors = do
    -- Report unsupported constructors up front instead of letting the pure
    -- helpers below hit 'error' while the declarations are being spliced.
    mapM_ checkSupported ctors

    -- newtype <dest> vs = <dest> { fold<dest> :: foldType }
    ns <- bang noSourceUnpackedness noSourceStrictness
    let nameR = mkName dest
        foldNameR = mkName ("fold" ++ dest)
        newType =
            NewtypeD ctx nameR binders Nothing
                (RecC nameR [(foldNameR, ns, foldType)]) []

    -- to<dest> :: T vs -> <dest> vs
    -- to<dest> x = <dest> (\f1 … fn -> case x of { C1 … -> f1 …; … })
    let nameToNameR = mkName ("to" ++ dest)
        xname = mkName "x"
    funs <- ctorFuns
    matches <- ctorMatches funs foldNameR nameToNameR
    let lam = LamE (map VarP funs) (CaseE (VarE xname) matches)
        toNameR =
            [ SigD nameToNameR (funSig ArrowT name nameR)
            , FunD nameToNameR
                [Clause [VarP xname] (NormalB (AppE (ConE nameR) lam)) []]
            ]

    -- from<dest> :: <dest> vs -> T vs
    -- from<dest> x = fold<dest> x C1 … Cn
    let nameFromNameR = mkName ("from" ++ dest)
        fromBody =
            NormalB (funCall (VarE foldNameR)
                             (VarE xname : map (ConE . ctorName) ctors))
        fromNameR =
            [ SigD nameFromNameR (funSig ArrowT nameR name)
            , FunD nameFromNameR [Clause [VarP xname] fromBody []]
            ]

    -- iso<dest> :: Iso' (T vs) (<dest> vs), only when lens is in scope.
    miso <- lookupValueName "Control.Lens.iso"
    misoTyp <- lookupTypeName "Control.Lens.Iso'"
    let isoNameR = mkName ("iso" ++ dest)
        lensIso = case (miso, misoTyp) of
            (Just iso, Just isoTy) ->
                [ SigD isoNameR (funSig (ConT isoTy) name nameR)
                , FunD isoNameR
                    [ Clause []
                        (NormalB (funCall (VarE iso)
                                          [VarE nameToNameR, VarE nameFromNameR]))
                        []
                    ]
                ]
            _ -> []

    pure (newType : toNameR ++ fromNameR ++ lensIso)
  where
    -- Fail the splice for constructors this encoding cannot handle.
    checkSupported :: Con -> Q ()
    checkSupported = \case
        ForallC _ _ co -> checkSupported co
        GadtC {} -> fail "Unsupported: GadtC"
        RecGadtC {} -> fail "Unsupported: RecGadtC"
        _ -> pure ()

    -- The Church encoding of the input type:
    --   forall r. (fields of C1 -> r) -> … -> (fields of Cn -> r) -> r
    -- with every occurrence of the fully applied input type replaced by r.
    foldType =
        let r = mkName "r"
            rt = VarT r
        in substituteType nameBaseType rt
             $ ForallT [PlainTV r] []
             $ funType rt
             $ map (ctorToFunc rt) ctors
      where
        ctorToFunc r (NormalC _ ts) = funType r (map snd ts)
        ctorToFunc r (RecC _ ts) = funType r (map (\(_, _, x) -> x) ts)
        ctorToFunc r (InfixC t1 _ t2) = funType r [snd t1, snd t2]
        ctorToFunc r (ForallC bs ct co) = ForallT bs ct (ctorToFunc r co)
        ctorToFunc _ GadtC {} = error "Unsupported: GadtC"
        ctorToFunc _ RecGadtC {} = error "Unsupported: RecGadtC"

    -- The name of a (possibly existentially quantified) constructor.
    ctorName (NormalC n _) = n
    ctorName (RecC n _) = n
    ctorName (InfixC _ n _) = n
    ctorName (ForallC _ _ co) = ctorName co
    ctorName (GadtC _ _ _) = error "Unsupported: GadtC"
    ctorName (RecGadtC _ _ _) = error "Unsupported: RecGadtC"

    -- One fresh variable per constructor, used as the fold arguments of
    -- to<dest> and named after the constructor (Cons -> cons').
    ctorFuns = mapM (newName . funNameFor . nameBase . ctorName) ctors

    -- Simply lowercasing the constructor name is not enough: constructors
    -- such as If or Case would yield reserved words, and symbolic ones such
    -- as :+: are not valid variable names, so GHC would reject the generated
    -- patterns. The trailing prime avoids the former; "f" covers the latter.
    funNameFor (c : cs) | isAlpha c = toLower c : cs ++ "'"
    funNameFor _ = "f"

    -- One case alternative per constructor. It binds every field to a fresh
    -- variable and passes the fields to the matching fold argument. A field
    -- whose type is the fully applied input type is first converted with
    -- to<dest> and folded with the same fold arguments.
    ctorMatches funs foldNameR nameToNameR =
        forM (zip funs ctors) $ \(f, c) -> do
            let recs = recursiveFields c
            args <- replicateM (length recs) (newName "a")
            pure $ Match (ConP (ctorName c) (map VarP args))
                         (NormalB (foldl' app (VarE f) (zip (map VarE args) recs)))
                         []
      where
        -- For each field of a constructor: is it a recursive occurrence?
        recursiveFields (NormalC _ ts) = map ((== nameBaseType) . snd) ts
        recursiveFields (RecC _ ts) = map (\(_, _, t) -> t == nameBaseType) ts
        recursiveFields (InfixC t1 _ t2) =
            [snd t1 == nameBaseType, snd t2 == nameBaseType]
        recursiveFields (ForallC _ _ co) = recursiveFields co
        recursiveFields (GadtC _ _ _) = error "Unsupported: GadtC"
        recursiveFields (RecGadtC _ _ _) = error "Unsupported: RecGadtC"

        app acc (arg, recurse) = AppE acc $
            if recurse
                then funCall (AppE (VarE foldNameR) (AppE (VarE nameToNameR) arg))
                             (map VarE funs)
                else arg

    -- nm applied to the binders' type variables, e.g. "T a b c" when there
    -- are three binders.
    baseType nm = foldl' (\acc -> AppT acc . VarT) (ConT nm) (map binderName binders)

    binderName = \case
        PlainTV n -> n
        KindedTV n _ -> n

    -- The input type, fully applied; used to spot recursive occurrences.
    nameBaseType = baseType name

    -- The signature "forall binders. ctx => t (nF vs) (nT vs)", e.g.
    -- "T a -> R a" for t = ArrowT. Without binders the forall (and with it
    -- the context) is omitted.
    funSig t nF nT =
        let ty = AppT (AppT t (baseType nF)) (baseType nT)
        in if null binders then ty else ForallT binders ctx ty