module Language.HERMIT.Primitive.FixPoint ( -- * Operations on the Fixed Point Operator (fix) -- | Note that many of these operations require 'Data.Function.fix' to be in scope. Language.HERMIT.Primitive.FixPoint.externals -- ** Rewrites and BiRewrites on Fixed Points , fixIntro , fixComputationRule , rollingRule ) where import GhcPlugins as GHC hiding (varName) import Control.Applicative import Control.Arrow import Data.Monoid (mempty) import Language.HERMIT.Core import Language.HERMIT.Monad import Language.HERMIT.Kure import Language.HERMIT.External import Language.HERMIT.GHC import Language.HERMIT.Primitive.AlphaConversion import Language.HERMIT.Primitive.Common import Language.HERMIT.Primitive.GHC import Language.HERMIT.Primitive.Local import Language.HERMIT.Primitive.Navigation import Language.HERMIT.Primitive.New -- TODO: Sort out heirarchy import Language.HERMIT.Primitive.Unfold import qualified Language.Haskell.TH as TH -------------------------------------------------------------------------------------------------- -- | Externals for manipulating fixed points, and for the worker/wrapper transformation. 
externals :: [External]
externals = map (.+ Experiment)
         [ external "fix-intro" (promoteDefR fixIntro :: RewriteH Core)
                [ "rewrite a recursive binding into a non-recursive binding using fix" ] .+ Introduce .+ Context
         , external "fix-computation" ((promoteExprBiR fixComputationRule) :: BiRewriteH Core)
                [ "Fixed-Point Computation Rule",
                  "fix t f <==> f (fix t f)"
                ]
         , external "rolling-rule" ((promoteExprBiR rollingRule) :: BiRewriteH Core)
                [ "Rolling Rule",
                  "fix tyA (\\ a -> f (g a)) <==> f (fix tyB (\\ b -> g (f b)))"
                ]
--       , external "fix-spec" (promoteExprR fixSpecialization :: RewriteH Core)
--              [ "specialize a fix with a given argument"] .+ Shallow
         , external "ww-factorisation" ((\ wrap unwrap -> promoteExprBiR $ workerWrapperFac wrap unwrap)
                                         :: CoreString -> CoreString -> BiRewriteH Core)
                [ "Worker/Wrapper Factorisation",
                  "For any \"f :: a -> a\", and given \"wrap :: b -> a\" and \"unwrap :: a -> b\" as arguments, then",
                  "fix tyA f <==> wrap (fix tyB (\\ b -> unwrap (f (wrap b))))",
                  "Note: the pre-condition \"fix tyA (\\ a -> wrap (unwrap (f a))) == fix tyA f\" is expected to hold."
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-fusion" ((\ wrap unwrap work -> promoteExprBiR $ workerWrapperFusion wrap unwrap work)
                                  :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
                [ "Worker/Wrapper Fusion",
                  "Given \"wrap :: b -> a\", \"unwrap :: a -> b\" and \"work :: b\" as arguments, then",
                  "unwrap (wrap work) <==> work",
                  "Note: the pre-conditions \"fix tyA (\\ a -> wrap (unwrap (f a))) == fix tyA f\"",
                  " and \"work == fix tyB (\\ b -> unwrap (f (wrap b)))\" are expected to hold."
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-split" ((\ wrap unwrap -> promoteDefR $ workerWrapperSplit wrap unwrap)
                                 :: CoreString -> CoreString -> RewriteH Core)
                [ "Worker/Wrapper Split",
                  "For any \"g :: a\", and given \"wrap :: b -> a\" and \"unwrap :: a -> b\" as arguments, then",
                  "g = expr ==> g = let f = \\ g -> expr",
                  " in let work = unwrap (f (wrap work))",
                  " in wrap work",
                  "Note: the pre-condition \"fix a (wrap . unwrap . f) == fix a f\" is expected to hold."
                ] .+ Introduce .+ Context .+ PreCondition
         , external "ww-split-param" ((\ n wrap unwrap -> promoteDefR $ workerWrapperSplitParam n wrap unwrap)
                                       :: Int -> CoreString -> CoreString -> RewriteH Core)
                [ "Worker/Wrapper Split - Type Parameter Variant",
                  "For any \"g :: forall t1 t2 .. tn . a\", and given \"wrap :: forall t1 t2 .. tn . b -> a\" and \"unwrap :: forall t1 t2 .. tn . a -> b\" as arguments, then",
                  "g = expr ==> g = \\ t1 t2 .. tn -> let f = \\ g -> expr t1 t2 .. tn",
                  " in let work = unwrap t1 t2 .. tn (f (wrap t1 t2 ..tn work))",
                  " in wrap t1 t2 .. tn work"
                ] .+ Introduce .+ Context .+ PreCondition .+ TODO .+ Experiment
         , external "ww-assumption-A" ((\ wrap unwrap -> promoteExprBiR $ wwA wrap unwrap)
                                        :: CoreString -> CoreString -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption A",
                  "For a \"wrap :: b -> a\" and an \"unwrap :: a -> b\", then",
                  "wrap (unwrap x) <==> x",
                  "Note: only use this if it's true!"
                ] .+ Context .+ PreCondition
         , external "ww-assumption-B" ((\ wrap unwrap f -> promoteExprBiR $ wwB wrap unwrap f)
                                        :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption B",
                  "For a \"wrap :: b -> a\", an \"unwrap :: a -> b\", and an \"f :: a -> a\" then",
                  "wrap (unwrap (f x)) <==> f x",
                  "Note: only use this if it's true!"
                ] .+ Context .+ PreCondition
         , external "ww-assumption-C" ((\ wrap unwrap f -> promoteExprBiR $ wwC wrap unwrap f)
                                        :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
                [ "Worker/Wrapper Assumption C",
                  "For a \"wrap :: b -> a\", an \"unwrap :: a -> b\", and an \"f :: a -> a\" then",
                  "fix t (\\ x -> wrap (unwrap (f x))) <==> fix t f",
                  "Note: only use this if it's true!"
                ] .+ Context .+ PreCondition
         ]

--------------------------------------------------------------------------------------------------

-- | @f = e@   ==\>   @f = fix (\\ f -> e)@
--
--   The right-hand side becomes an application of 'Data.Function.fix' (built by 'mkFix')
--   to a lambda. The lambda binds a fresh clone of @f@, and @f@ is replaced by that clone
--   inside @e@ (via 'substR').
fixIntro :: RewriteH CoreDef
fixIntro = prefixFailMsg "fix introduction failed: " $ do
    Def f _ <- idR
    -- Fresh copy of the binder, used as the variable bound by the lambda passed to fix.
    f' <- constT $ cloneVarH id f
    Def f <$> (mkFix =<< defT mempty (extractR $ substR f $ varToCoreExpr f') (\ () e' -> Lam f' e'))

--------------------------------------------------------------------------------------------------

-- | @fix ty f@  \<==\>  @f (fix ty f)@
--
--   Left-to-right unrolls one step of the fixed point.
--   Right-to-left requires the outer function to be equal (by 'exprEqual') to the
--   function inside the fix.
fixComputationRule :: BiRewriteH CoreExpr
fixComputationRule = bidirectional computationL computationR
  where
    computationL :: RewriteH CoreExpr
    computationL =
      prefixFailMsg "fix computation rule failed: " $ do
        (_,f) <- isFixExpr
        fixf  <- idR
        return (App f fixf)

    computationR :: RewriteH CoreExpr
    computationR =
      prefixFailMsg "fix computation rule failed: " $ do
        App f fixf <- idR
        (_,f') <- isFixExpr <<< constant fixf
        guardMsg (exprEqual f f') "external function does not match internal expression"
        return fixf

-- | @fix tyA (\\ a -> f (g a))@  \<==\>  @f (fix tyB (\\ b -> g (f b)))@
--
--   Side condition, checked in both directions: the variable bound by the lambda inside
--   the fix must not occur free in @f@ or @g@. Otherwise the rewrite would move an
--   occurrence of that variable outside its binder.
rollingRule :: BiRewriteH CoreExpr
rollingRule = bidirectional rollingRuleL rollingRuleR
  where
    rollingRuleL :: RewriteH CoreExpr
    rollingRuleL =
      prefixFailMsg "rolling rule failed: " $
      withPatFailMsg wrongFixBody $ do
        (tyA, Lam a (App f (App g (Var a')))) <- isFixExpr
        guardMsg (a == a') wrongFixBody
        -- f is moved outside the fix below, so neither f nor g may depend on 'a'.
        guardMsg (not (any (a `occursFreeIn`) [f, g])) boundVarEscapes
        (tyA',tyB) <- funsWithInverseTypes g f
        guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
        res <- rollingRuleResult tyB g f
        return (App f res)

    rollingRuleR :: RewriteH CoreExpr
    rollingRuleR =
      prefixFailMsg "(reversed) rolling rule failed: " $
      withPatFailMsg "not an application." $ do
        App f fx <- idR
        withPatFailMsg wrongFixBody $ do
          (tyB, Lam b (App g (App f' (Var b')))) <- isFixExpr <<< constant fx
          guardMsg (b == b') wrongFixBody
          -- g is moved outside its binder below, so neither g nor f' may depend on 'b'.
          guardMsg (not (any (b `occursFreeIn`) [g, f'])) boundVarEscapes
          guardMsg (exprEqual f f') "external function does not match internal expression"
          (tyA,tyB') <- funsWithInverseTypes g f
          guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
          rollingRuleResult tyA f g

    -- Builds @fix ty (\\ x -> f (g x))@ with a fresh binder @x :: ty@.
    rollingRuleResult :: Type -> CoreExpr -> CoreExpr -> TranslateH z CoreExpr
    rollingRuleResult ty f g = do
        x <- constT (newIdH "x" ty)
        mkFix (Lam x (App f (App g (Var x))))

    -- True when the variable appears among the free variables of the expression.
    occursFreeIn :: Var -> CoreExpr -> Bool
    occursFreeIn v e = v `elemVarSet` exprFreeVars e

    boundVarEscapes :: String
    boundVarEscapes = "the variable bound in the body of fix occurs free in f or g."

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form: Lam v (App f (App g (Var v)))"

--------------------------------------------------------------------------------------------------

-- ironically, this is an instance of worker/wrapper itself.

-- fixSpecialization :: RewriteH CoreExpr
-- fixSpecialization = do
--         -- fix (t::*) (f :: t -> t) (a :: t) :: t
--         App (App (App (Var fixId) (Type _)) _) _ <- idR
--
--         guardIsFixId fixId
--
--         let r :: RewriteH CoreExpr
--             r = multiEtaExpand [TH.mkName "f",TH.mkName "a"]
--
--             sub :: RewriteH Core
--             sub = pathR [0,1] (promoteR r)
--
--         App (App (App (Var fx) (Type t))
--                  (Lam _ (Lam v2 (App (App e _) _a2)))
--             )
--             (Type t2) <- extractR sub -- In normal form now
--
--         constT $ do let t' = applyTy t t2
--                     v3 <- newIdH "f" t'
--                     v4 <- newTyVarH "a" (tyVarKind v2)
--
--                     -- f' :: \/ a -> T [a] -> (\/ b .
--                     --       T [b])
--                     let f' = Lam v4 (Cast (Var v3) (mkUnsafeCo t' (applyTy t (mkTyVarTy v4))))
--                         e' = Lam v3 (App (App e f') (Type t2))
--                     return $ App (App (Var fx) (Type t')) e'

--------------------------------------------------------------------------------------------------

-- | For any @f :: a -> a@, and given @wrap :: b -> a@ and @unwrap :: a -> b@ as arguments, then
--   @fix tyA f@  \<==\>  @wrap (fix tyB (\\ b -> unwrap (f (wrap b))))@
--
--   @tyA@ and @tyB@ are computed by 'wrapUnwrapTypes' before either direction runs.
--   For the right-to-left direction, three conditions must hold:
--
--   * the applied wrapper must equal the given @wrap@ (by 'exprEqual');
--   * the wrapper and unwrapper inside the fix must equal the given @wrap@ and @unwrap@;
--   * the variable bound inside the fix must not occur free in @f@.
workerWrapperFacBR :: CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
workerWrapperFacBR wrap unwrap = beforeBiR (wrapUnwrapTypes wrap unwrap)
                                           (\ (tyA,tyB) -> bidirectional (wwL tyA tyB) wwR)
  where
    wwL :: Type -> Type -> RewriteH CoreExpr
    wwL tyA tyB =
      prefixFailMsg "worker/wrapper factorisation failed: " $ do
        (tA,f) <- isFixExpr
        guardMsg (eqType tyA tA) "wrapper/unwrapper types do not match fix body type."
        b  <- constT (newIdH "x" tyB)
        fx <- mkFix (Lam b (App unwrap (App f (App wrap (Var b)))))
        return $ App wrap fx

    wwR :: RewriteH CoreExpr
    wwR =
      prefixFailMsg "(reverse) worker/wrapper factorisation failed: " $
      withPatFailMsg "not an application." $ do
        App wrap2 fx <- idR
        withPatFailMsg wrongFixBody $ do
          (_, Lam b (App unwrap1 (App f (App wrap1 (Var b'))))) <- isFixExpr <<< constant fx
          guardMsg (b == b') wrongFixBody
          -- f is moved out from under the lambda below, so it must not mention 'b';
          -- otherwise the result would contain an unbound variable.
          guardMsg (not (b `elemVarSet` exprFreeVars f)) "the variable bound in the body of fix occurs free in f."
          guardMsg (exprEqual wrap wrap2) "given wrapper does not match applied function."
          guardMsg (exprEqual wrap wrap1) "given wrapper does not match wrapper in body of fix."
          guardMsg (exprEqual unwrap unwrap1) "given unwrapper does not match unwrapper in body of fix."
          mkFix f

    wrongFixBody :: String
    wrongFixBody = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"

-- | For any @f :: a -> a@, and given @wrap :: b -> a@ and @unwrap :: a -> b@ as arguments, then
--   @fix tyA f@  \<==\>  @wrap (fix tyB (\\ b -> unwrap (f (wrap b))))@
--
--   String-argument front end to 'workerWrapperFacBR'; both strings are parsed as Core expressions.
workerWrapperFac :: CoreString -> CoreString -> BiRewriteH CoreExpr
workerWrapperFac = parse2beforeBiR workerWrapperFacBR

--------------------------------------------------------------------------------------------------

-- | Given @wrap :: b -> a@, @unwrap :: a -> b@ and @work :: b@ as arguments, then
--   @unwrap (wrap work)@  \<==\>  @work@
--
--   Before either direction runs, checks that @wrap@/@unwrap@ form a valid pair
--   (via 'wrapUnwrapTypes') and that the type of @work@ is the second type of that pair.
workerWrapperFusionBR :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
workerWrapperFusionBR wrap unwrap work =
    beforeBiR (prefixFailMsg "worker/wrapper fusion failed: " $
                 do (_,tyB) <- wrapUnwrapTypes wrap unwrap
                    guardMsg (exprType work `eqType` tyB) "type of worker does not match types of wrap/unwrap.")
              (\ () -> bidirectional fusL fusR)
  where
    fusL :: RewriteH CoreExpr
    fusL =
      prefixFailMsg "worker/wrapper fusion failed: " $
      withPatFailMsg (wrongExprForm "App unwrap (App wrap work)") $ do
        App unwrap' (App wrap' work') <- idR
        guardMsg (exprEqual wrap wrap') "given wrapper does not match wrapper in expression."
        guardMsg (exprEqual unwrap unwrap') "given unwrapper does not match unwrapper in expression."
        guardMsg (exprEqual work work') "given worker function does not match worker in expression."
        return work

    fusR :: RewriteH CoreExpr
    fusR =
      prefixFailMsg "(reverse) worker/wrapper fusion failed: " $ do
        work' <- idR
        guardMsg (exprEqual work work') "given worker function does not match expression."
        return $ App unwrap (App wrap work)

-- | Given @wrap :: b -> a@, @unwrap :: a -> b@ and @work :: b@ as arguments, then
--   @unwrap (wrap work)@  \<==\>  @work@
--
--   String-argument front end to 'workerWrapperFusionBR'.
workerWrapperFusion :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
workerWrapperFusion = parse3beforeBiR workerWrapperFusionBR

--------------------------------------------------------------------------------------------------

-- | Given @wrap@ and @unwrap@:
--   @g = expr@  ==\>  @g = let f = \\ g -> expr in let work = unwrap (f (wrap work)) in wrap work@
--
--   The rewrite proceeds in these steps:
--
--   1. Introduce @fix@ ('fixIntro').
--   2. Let-bind its argument as @f@.
--   3. Apply worker/wrapper factorisation.
--   4. Let-bind the factorised fix argument as @w@.
--   5. Unfold @fix@ into a let binding named @work@, simplify, then substitute and float the lets.
workerWrapperSplitR :: CoreExpr -> CoreExpr -> RewriteH CoreDef
workerWrapperSplitR wrap unwrap =
  let f    = TH.mkName "f"
      w    = TH.mkName "w"
      work = TH.mkName "work"
      fx   = TH.mkName "fix"
   in fixIntro
      >>> defAllR idR
            (   appAllR idR (letIntro f)
            >>> letFloatArg
            >>> letAllR idR
                  (   forewardT (workerWrapperFacBR wrap unwrap)
                  >>> appAllR idR (letIntro w)
                  >>> letFloatArg
                  >>> letNonRecAllR idR (unfoldNameR fx >>> alphaLetWith [work] >>> extractR simplifyR) idR
                  >>> letSubstR
                  >>> letFloatArg
                  )
            )

-- | Given @wrap@ and @unwrap@ (as strings):
--   @g = expr@  ==\>  @g = let f = \\ g -> expr in let work = unwrap (f (wrap work)) in wrap work@
workerWrapperSplit :: CoreString -> CoreString -> RewriteH CoreDef
workerWrapperSplit wrapS unwrapS = (parseCoreExprT wrapS &&& parseCoreExprT unwrapS) >>= uncurry workerWrapperSplitR

-- | As 'workerWrapperSplit', but first performs the static-argument transformation for @n@
--   type parameters. Those types are then supplied as arguments to every use of wrap and unwrap.
--   This is useful if the expression, and wrap and unwrap, all have a @forall@ type.
--   Currently only @n == 0@ (plain 'workerWrapperSplit') and @n == 1@ are supported.
workerWrapperSplitParam :: Int -> CoreString -> CoreString -> RewriteH CoreDef
workerWrapperSplitParam 0 = workerWrapperSplit
workerWrapperSplitParam n = \ wrapS unwrapS ->
  prefixFailMsg "worker/wrapper split (forall variant) failed: " $ do
    guardMsg (n == 1) "currently only supports 1 type parameter."
    withPatFailMsg "right-hand-side of definition does not have the form: Lam t e" $ do
      Def _ (Lam t _) <- idR
      guardMsg (isTyVar t) "first argument is not a type."
      let splitAtDefR :: RewriteH Core
          splitAtDefR = do
            p <- considerConstructT Definition
            pathR p $ promoteR $ do
              wrap   <- parseCoreExprT wrapS
              unwrap <- parseCoreExprT unwrapS
              -- Instantiate wrap/unwrap at the abstracted type variable.
              let ty = Type (TyVarTy t)
              workerWrapperSplitR (App wrap ty) (App unwrap ty)
      staticArg >>> extractR splitAtDefR

--------------------------------------------------------------------------------------------------

-- | @wrap (unwrap x)@  \<==\>  @x@
--
--   Right-to-left checks that the type of @x@ matches the first type returned by 'wrapUnwrapTypes'.
wwAssA :: CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssA wrap unwrap = beforeBiR (wrapUnwrapTypes wrap unwrap)
                               (\ (tyA,_) -> bidirectional wwAL (wwAR tyA))
  where
    wwAL :: RewriteH CoreExpr
    wwAL =
      withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $ do
        App wrap' (App unwrap' x) <- idR
        guardMsg (exprEqual wrap wrap') "given wrapper does not match wrapper in expression."
        guardMsg (exprEqual unwrap unwrap') "given unwrapper does not match unwrapper in expression."
        return x

    wwAR :: Type -> RewriteH CoreExpr
    wwAR tyA = do
        x <- idR
        guardMsg (exprType x `eqType` tyA) "type of expression does not match types of wrap/unwrap."
        return $ App wrap (App unwrap x)

-- | @wrap (unwrap x)@  \<==\>  @x@  (string-argument front end to 'wwAssA')
wwA :: CoreString -> CoreString -> BiRewriteH CoreExpr
wwA = parse2beforeBiR wwAssA

-- | @wrap (unwrap (f x))@  \<==\>  @f x@
--
--   Both directions check that the function position matches the given @f@.
--   They then apply 'wwAssA' to the whole expression.
wwAssB :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssB wrap unwrap f = bidirectional wwBL wwBR
  where
    assA :: BiRewriteH CoreExpr
    assA = wwAssA wrap unwrap

    wwBL :: RewriteH CoreExpr
    wwBL =
      withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f x))") $ do
        App _ (App _ (App f' _)) <- idR
        guardMsg (exprEqual f f') "given body function does not match expression."
        forewardT assA

    wwBR :: RewriteH CoreExpr
    wwBR =
      withPatFailMsg (wrongExprForm "App f x") $ do
        App f' _ <- idR
        guardMsg (exprEqual f f') "given body function does not match expression."
        backwardT assA

-- | @wrap (unwrap (f x))@  \<==\>  @f x@  (string-argument front end to 'wwAssB')
wwB :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwB = parse3beforeBiR wwAssB

-- | @fix t (\\ x -> wrap (unwrap (f x)))@  \<==\>  @fix t f@
--
--   Both directions first require the expression to be an application of fix ('isFixExpr').
--   They then rewrite the function argument of the fix using 'wwAssB':
--   left-to-right eta-reduces afterwards, and right-to-left eta-expands (binder named "x") first.
wwAssC :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssC wrap unwrap f = beforeBiR isFixExpr (\ _ -> bidirectional wwCL wwCR)
  where
    assB :: BiRewriteH CoreExpr
    assB = wwAssB wrap unwrap f

    wwCL :: RewriteH CoreExpr
    wwCL = appAllR idR (lamAllR idR (forewardT assB) >>> etaReduce)

    wwCR :: RewriteH CoreExpr
    wwCR = appAllR idR (etaExpand "x" >>> lamAllR idR (backwardT assB))

-- | @fix t (\\ x -> wrap (unwrap (f x)))@  \<==\>  @fix t f@  (string-argument front end to 'wwAssC')
wwC :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwC = parse3beforeBiR wwAssC

--------------------------------------------------------------------------------------------------

-- | Check that the expression has the form "fix t (f :: t -> t)", returning "t" and "f".
--   The head variable must be the same Id as the one found by 'findFixId'.
isFixExpr :: TranslateH CoreExpr (Type,CoreExpr)
isFixExpr =
  withPatFailMsg (wrongExprForm "fix t f") $ do
    -- fix :: forall a. (a -> a) -> a
    App (App (Var fixId) (Type ty)) f <- idR
    fixId' <- findFixId
    guardMsg (fixId == fixId') (var2String fixId ++ " does not match " ++ fixLocation)
    return (ty,f)

-- | Compute the pair of types associated with a wrap/unwrap pair, via 'funsWithInverseTypes'.
--   Callers in this module bind the result as @(tyA, tyB)@, for @wrap :: tyB -> tyA@ and
--   @unwrap :: tyA -> tyB@. Fails with a uniform message if the types are not inverse.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap =
  setFailMsg "given expressions have the wrong types to form a valid wrap/unwrap pair." $
  funsWithInverseTypes unwrap wrap

--------------------------------------------------------------------------------------------------

-- | @f@  ==\>  @fix t f@, where the type argument @t@ is computed from @f@ by 'endoFunType'.
mkFix :: CoreExpr -> TranslateH z CoreExpr
mkFix f = do
    t     <- endoFunType f
    fixId <- findFixId
    return $ mkCoreApps (varToCoreExpr fixId) [Type t, f]

-- | Fully qualified name of the fixed-point combinator these rewrites operate on.
fixLocation :: String
fixLocation = "Data.Function.fix"

-- | Look up the Id for 'fixLocation' (via 'findIdT').
--   NOTE(review): presumably fails when @Data.Function.fix@ is not in scope — confirm against 'findIdT'.
findFixId :: TranslateH a Id
findFixId = findIdT (TH.mkName fixLocation)

--------------------------------------------------------------------------------------------------

-- | Parse two Core expression strings, then build a bidirectional rewrite from the results.
parse2beforeBiR :: (CoreExpr -> CoreExpr -> BiRewriteH a) -> CoreString -> CoreString -> BiRewriteH a
parse2beforeBiR f s1 s2 = beforeBiR (parseCoreExprT s1 &&& parseCoreExprT s2) (uncurry f)

-- | Parse three Core expression strings, then build a bidirectional rewrite from the results.
parse3beforeBiR :: (CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH a) -> CoreString -> CoreString -> CoreString -> BiRewriteH a
parse3beforeBiR f s1 s2 s3 = beforeBiR ((parseCoreExprT s1 &&& parseCoreExprT s2) &&& parseCoreExprT s3) ((uncurry.uncurry) f)

--------------------------------------------------------------------------------------------------