module Language.HERMIT.Primitive.FixPoint
(
Language.HERMIT.Primitive.FixPoint.externals
, fixIntro
, fixComputationRule
, rollingRule
)
where
import GhcPlugins as GHC hiding (varName)
import Control.Applicative
import Control.Arrow
import Data.Monoid (mempty)
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.Primitive.AlphaConversion
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Local
import Language.HERMIT.Primitive.Navigation
import Language.HERMIT.Primitive.New
import Language.HERMIT.Primitive.Unfold
import qualified Language.Haskell.TH as TH
-- | The shell commands provided by this module.
--
-- Every entry is additionally tagged with 'Experiment' by the outer 'map'.
-- The help text of several entries has been corrected: the rolling-rule
-- formula was missing a closing parenthesis, the fusion pre-condition did not
-- apply @wrap@ to its argument, and the assumption commands stated the type of
-- @unwrap@ as @b -> a@ although a wrap/unwrap pair must have inverse types.
externals :: [External]
externals = map (.+ Experiment)
    [ external "fix-intro" (promoteDefR fixIntro :: RewriteH Core)
        [ "rewrite a recursive binding into a non-recursive binding using fix"
        ] .+ Introduce .+ Context
    , external "fix-computation" ((promoteExprBiR fixComputationRule) :: BiRewriteH Core)
        [ "Fixed-Point Computation Rule",
          "fix t f <==> f (fix t f)"
        ]
    , external "rolling-rule" ((promoteExprBiR rollingRule) :: BiRewriteH Core)
        [ "Rolling Rule",
          "fix tyA (\\ a -> f (g a)) <==> f (fix tyB (\\ b -> g (f b)))"
        ]
    , external "ww-factorisation" ((\ wrap unwrap -> promoteExprBiR $ workerWrapperFac wrap unwrap) :: CoreString -> CoreString -> BiRewriteH Core)
        [ "Worker/Wrapper Factorisation",
          "For any \"f :: a -> a\", and given \"wrap :: b -> a\" and \"unwrap :: a -> b\" as arguments, then",
          "fix tyA f <==> wrap (fix tyB (\\ b -> unwrap (f (wrap b))))",
          "Note: the pre-condition \"fix tyA (\\ a -> wrap (unwrap (f a))) == fix tyA f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-fusion" ((\ wrap unwrap work -> promoteExprBiR $ workerWrapperFusion wrap unwrap work) :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
        [ "Worker/Wrapper Fusion",
          "Given \"wrap :: b -> a\", \"unwrap :: a -> b\" and \"work :: b\" as arguments, then",
          "unwrap (wrap work) <==> work",
          "Note: the pre-conditions \"fix tyA (\\ a -> wrap (unwrap (f a))) == fix tyA f\"",
          " and \"work == fix tyB (\\ b -> unwrap (f (wrap b)))\" are expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split" ((\ wrap unwrap -> promoteDefR $ workerWrapperSplit wrap unwrap) :: CoreString -> CoreString -> RewriteH Core)
        [ "Worker/Wrapper Split",
          "For any \"g :: a\", and given \"wrap :: b -> a\" and \"unwrap :: a -> b\" as arguments, then",
          "g = expr ==> g = let f = \\ g -> expr",
          " in let work = unwrap (f (wrap work))",
          " in wrap work",
          "Note: the pre-condition \"fix a (wrap . unwrap . f) == fix a f\" is expected to hold."
        ] .+ Introduce .+ Context .+ PreCondition
    , external "ww-split-param" ((\ n wrap unwrap -> promoteDefR $ workerWrapperSplitParam n wrap unwrap) :: Int -> CoreString -> CoreString -> RewriteH Core)
        [ "Worker/Wrapper Split - Type Parameter Variant",
          "For any \"g :: forall t1 t2 .. tn . a\", and given \"wrap :: forall t1 t2 .. tn . b -> a\" and \"unwrap :: forall t1 t2 .. tn . a -> b\" as arguments, then",
          "g = expr ==> g = \\ t1 t2 .. tn -> let f = \\ g -> expr t1 t2 .. tn",
          " in let work = unwrap t1 t2 .. tn (f (wrap t1 t2 ..tn work))",
          " in wrap t1 t2 .. tn work"
        ] .+ Introduce .+ Context .+ PreCondition .+ TODO .+ Experiment
    , external "ww-assumption-A" ((\ wrap unwrap -> promoteExprBiR $ wwA wrap unwrap) :: CoreString -> CoreString -> BiRewriteH Core)
        [ "Worker/Wrapper Assumption A",
          "For a \"wrap :: b -> a\" and an \"unwrap :: a -> b\", then",
          "wrap (unwrap x) <==> x",
          "Note: only use this if it's true!"
        ] .+ Context .+ PreCondition
    , external "ww-assumption-B" ((\ wrap unwrap f -> promoteExprBiR $ wwB wrap unwrap f) :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
        [ "Worker/Wrapper Assumption B",
          "For a \"wrap :: b -> a\", an \"unwrap :: a -> b\", and an \"f :: a -> a\" then",
          "wrap (unwrap (f x)) <==> f x",
          "Note: only use this if it's true!"
        ] .+ Context .+ PreCondition
    , external "ww-assumption-C" ((\ wrap unwrap f -> promoteExprBiR $ wwC wrap unwrap f) :: CoreString -> CoreString -> CoreString -> BiRewriteH Core)
        [ "Worker/Wrapper Assumption C",
          "For a \"wrap :: b -> a\", an \"unwrap :: a -> b\", and an \"f :: a -> a\" then",
          "fix t (\\ x -> wrap (unwrap (f x))) <==> fix t f",
          "Note: only use this if it's true!"
        ] .+ Context .+ PreCondition
    ]
-- | Rewrite a definition @f = e@ into @f = fix (\\ f' -> e[f'/f])@.
--
-- A clone @f'@ of the binder is created, occurrences of @f@ in the
-- right-hand side are substituted by @f'@, the result is abstracted over
-- @f'@ and handed to 'mkFix'.  Failures are prefixed with
-- @"fix introduction failed: "@.
fixIntro :: RewriteH CoreDef
fixIntro = prefixFailMsg "fix introduction failed: " $ do
    Def f _ <- idR
    fresh   <- constT (cloneVarH id f)
    -- Rewrite the right-hand side so that it mentions the clone, then bind the clone.
    body    <- defT mempty
                    (extractR (substR f (varToCoreExpr fresh)))
                    (\ () rhs -> Lam fresh rhs)
    Def f <$> mkFix body
-- | The fixed-point computation rule as a bidirectional rewrite:
--
-- >   fix t f  <==>  f (fix t f)
fixComputationRule :: BiRewriteH CoreExpr
fixComputationRule = bidirectional unrollOnce rollOnce
  where
    -- Left to right: the current expression must be a fix expression; apply
    -- its body function to the whole (unchanged) expression.
    unrollOnce :: RewriteH CoreExpr
    unrollOnce = prefixFailMsg "fix computation rule failed: " $ do
        (_, body) <- isFixExpr
        whole     <- idR
        return (App body whole)

    -- Right to left: the current expression must be an application whose
    -- argument is a fix expression over the very function being applied.
    rollOnce :: RewriteH CoreExpr
    rollOnce = prefixFailMsg "fix computation rule failed: " $ do
        App outer inner <- idR
        (_, body)       <- constant inner >>> isFixExpr
        guardMsg (exprEqual outer body) "external function does not match internal expression"
        return inner
-- | The rolling rule as a bidirectional rewrite:
--
-- >   fix tyA (\ a -> f (g a))  <==>  f (fix tyB (\ b -> g (f b)))
rollingRule :: BiRewriteH CoreExpr
rollingRule = bidirectional rollForward rollBackward
  where
    -- Left to right: match a fix whose body is a lambda applying f to (g a).
    rollForward :: RewriteH CoreExpr
    rollForward =
      prefixFailMsg "rolling rule failed: " . withPatFailMsg badBodyMsg $ do
        (tyA, Lam a (App f (App g (Var a')))) <- isFixExpr
        guardMsg (a == a') badBodyMsg
        (tyA', tyB) <- funsWithInverseTypes g f
        guardMsg (eqType tyA tyA') "Type mismatch: this shouldn't have happened, report this as a bug."
        App f <$> rolledFix tyB g f

    -- Right to left: match f applied to a fix whose body applies g to (f b).
    rollBackward :: RewriteH CoreExpr
    rollBackward =
      prefixFailMsg "(reversed) rolling rule failed: " . withPatFailMsg "not an application." $ do
        App f fx <- idR
        withPatFailMsg badBodyMsg $ do
          (tyB, Lam b (App g (App f' (Var b')))) <- constant fx >>> isFixExpr
          guardMsg (b == b') badBodyMsg
          guardMsg (exprEqual f f') "external function does not match internal expression"
          (tyA, tyB') <- funsWithInverseTypes g f
          guardMsg (eqType tyB tyB') "Type mismatch: this shouldn't have happened, report this as a bug."
          rolledFix tyA f g

    -- Build @fix (\\ x -> outer (inner x))@ with a new binder "x" of the given type.
    rolledFix :: Type -> CoreExpr -> CoreExpr -> TranslateH z CoreExpr
    rolledFix ty outer inner = do
        x <- constT (newIdH "x" ty)
        mkFix (Lam x (App outer (App inner (Var x))))

    -- Message used when the body of the fix has the wrong shape.
    badBodyMsg :: String
    badBodyMsg = "body of fix does not have the form: Lam v (App f (App g (Var v)))"
-- | Worker/wrapper factorisation for already-parsed @wrap@ and @unwrap@:
--
-- >   fix tyA f  <==>  wrap (fix tyB (\ b -> unwrap (f (wrap b))))
--
-- The types of the pair are obtained from 'wrapUnwrapTypes' before either
-- direction is built.
workerWrapperFacBR :: CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
workerWrapperFacBR wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap) $ \ (tyA, tyB) ->
      bidirectional (factorise tyA tyB) unfactorise
  where
    -- Left to right: the fix must be at type tyA; the new fix binds "x" at tyB.
    factorise :: Type -> Type -> RewriteH CoreExpr
    factorise tyA tyB = prefixFailMsg "worker/wrapper factorisation failed: " $ do
        (tA, f) <- isFixExpr
        guardMsg (eqType tyA tA) "wrapper/unwrapper types do not match fix body type."
        b       <- constT (newIdH "x" tyB)
        worker  <- mkFix (Lam b (App unwrap (App f (App wrap (Var b)))))
        return (App wrap worker)

    -- Right to left: check that all three occurrences match the given
    -- wrap/unwrap, then rebuild the fix over the inner function alone.
    unfactorise :: RewriteH CoreExpr
    unfactorise =
      prefixFailMsg "(reverse) worker/wrapper factorisation failed: " . withPatFailMsg "not an application." $ do
        App outerWrap fx <- idR
        withPatFailMsg badBodyMsg $ do
          (_, Lam b (App innerUnwrap (App f (App innerWrap (Var b'))))) <- constant fx >>> isFixExpr
          guardMsg (b == b') badBodyMsg
          guardMsg (exprEqual wrap outerWrap) "given wrapper does not match applied function."
          guardMsg (exprEqual wrap innerWrap) "given wrapper does not match wrapper in body of fix."
          guardMsg (exprEqual unwrap innerUnwrap) "given unwrapper does not match unwrapper in body of fix."
          mkFix f

    -- Message used when the body of the fix has the wrong shape.
    badBodyMsg :: String
    badBodyMsg = "body of fix does not have the form Lam b (App unwrap (App f (App wrap (Var b))))"
-- | 'workerWrapperFacBR' with @wrap@ and @unwrap@ supplied as unparsed
-- core strings; parsing is delegated to 'parse2beforeBiR'.
workerWrapperFac :: CoreString -> CoreString -> BiRewriteH CoreExpr
workerWrapperFac wrapS unwrapS = parse2beforeBiR workerWrapperFacBR wrapS unwrapS
-- | Worker/wrapper fusion for already-parsed @wrap@, @unwrap@ and @work@:
--
-- >   unwrap (wrap work)  <==>  work
--
-- Before either direction runs, the types of the wrap/unwrap pair are
-- checked with 'wrapUnwrapTypes' and the type of @work@ is compared against
-- the second component of the result.
workerWrapperFusionBR :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
workerWrapperFusionBR wrap unwrap work =
    beforeBiR (prefixFailMsg "worker/wrapper fusion failed: " $
                 do (_,tyB) <- wrapUnwrapTypes wrap unwrap
                    guardMsg (exprType work `eqType` tyB) "type of worker does not match types of wrap/unwrap."
              )
              (\ () -> bidirectional fusL fusR)
  where
    -- Left to right: the expression must be exactly @unwrap (wrap work)@.
    fusL :: RewriteH CoreExpr
    fusL = prefixFailMsg "worker/wrapper fusion failed: " $
           withPatFailMsg (wrongExprForm "App unwrap (App wrap work)") $
           do App unwrap' (App wrap' work') <- idR
              guardMsg (exprEqual wrap wrap') "given wrapper does not match wrapper in expression."
              guardMsg (exprEqual unwrap unwrap') "given unwrapper does not match unwrapper in expression."
              -- Message fixed: it previously read "does not worker in expression."
              guardMsg (exprEqual work work') "given worker function does not match worker in expression."
              return work

    -- Right to left: the expression must be exactly @work@.
    fusR :: RewriteH CoreExpr
    fusR = prefixFailMsg "(reverse) worker/wrapper fusion failed: " $
           do work' <- idR
              guardMsg (exprEqual work work') "given worker function does not match expression."
              return $ App unwrap (App wrap work)
-- | 'workerWrapperFusionBR' with @wrap@, @unwrap@ and @work@ supplied as
-- unparsed core strings; parsing is delegated to 'parse3beforeBiR'.
workerWrapperFusion :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
workerWrapperFusion wrapS unwrapS workS =
    parse3beforeBiR workerWrapperFusionBR wrapS unwrapS workS
-- | Worker/wrapper split of a definition, for already-parsed @wrap@ and
-- @unwrap@.
--
-- The definition is first rewritten with 'fixIntro'; its right-hand side is
-- then transformed by the pipeline below, which includes one forward use of
-- 'workerWrapperFacBR'.  The stages run in exactly the order listed.
workerWrapperSplitR :: CoreExpr -> CoreExpr -> RewriteH CoreDef
workerWrapperSplitR wrap unwrap = fixIntro >>> defAllR idR splitRhs
  where
    -- Names introduced or looked up by the pipeline.
    fName    = TH.mkName "f"
    wName    = TH.mkName "w"
    workName = TH.mkName "work"
    fixName  = TH.mkName "fix"

    -- Applied to the right-hand side produced by 'fixIntro'.
    splitRhs =
          appAllR idR (letIntro fName)
      >>> letFloatArg
      >>> letAllR idR factoriseBody

    -- Applied to the body of the let introduced for "f".
    factoriseBody =
          forewardT (workerWrapperFacBR wrap unwrap)
      >>> appAllR idR (letIntro wName)
      >>> letFloatArg
      >>> letNonRecAllR idR nameWorker idR
      >>> letSubstR
      >>> letFloatArg

    -- Applied to the expression bound by the let introduced for "w".
    nameWorker =
          unfoldNameR fixName
      >>> alphaLetWith [workName]
      >>> extractR simplifyR
-- | 'workerWrapperSplitR' with @wrap@ and @unwrap@ supplied as unparsed
-- core strings.  The wrapper is parsed first, then the unwrapper.
workerWrapperSplit :: CoreString -> CoreString -> RewriteH CoreDef
workerWrapperSplit wrapS unwrapS = do
    wrap   <- parseCoreExprT wrapS
    unwrap <- parseCoreExprT unwrapS
    workerWrapperSplitR wrap unwrap
-- | Worker/wrapper split for a definition that abstracts over type
-- parameters.
--
-- With @0@ parameters this is just 'workerWrapperSplit'.  Any other count
-- must currently be exactly @1@; the right-hand side of the definition must
-- then be a lambda whose binder is a type variable.  After 'staticArg', the
-- split is performed at the path returned by
-- @considerConstructT Definition@, with @wrap@ and @unwrap@ each applied to
-- that type variable.
workerWrapperSplitParam :: Int -> CoreString -> CoreString -> RewriteH CoreDef
workerWrapperSplitParam 0 = workerWrapperSplit
workerWrapperSplitParam n = \ wrapS unwrapS -> prefixFailMsg "worker/wrapper split (forall variant) failed: " $
    -- Message fixed: "parameter" was previously misspelled.
    do guardMsg (n == 1) "currently only supports 1 type parameter."
       withPatFailMsg "right-hand-side of definition does not have the form: Lam t e" $
         do Def _ (Lam t _) <- idR
            guardMsg (isTyVar t) "first argument is not a type."
            let splitAtDefR :: RewriteH Core
                splitAtDefR = do p <- considerConstructT Definition
                                 pathR p $ promoteR $ do wrap   <- parseCoreExprT wrapS
                                                         unwrap <- parseCoreExprT unwrapS
                                                         -- Instantiate both functions at the bound type variable.
                                                         let ty = Type (TyVarTy t)
                                                         workerWrapperSplitR (App wrap ty) (App unwrap ty)
            staticArg >>> extractR splitAtDefR
-- | Worker/wrapper assumption A for already-parsed @wrap@ and @unwrap@:
--
-- >   wrap (unwrap x)  <==>  x
--
-- The types of the pair are obtained from 'wrapUnwrapTypes' first; only the
-- first component is used, to check the expression in the reverse direction.
wwAssA :: CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssA wrap unwrap =
    beforeBiR (wrapUnwrapTypes wrap unwrap) $ \ (tyA, _) ->
      bidirectional stripPair (insertPair tyA)
  where
    -- Left to right: remove a matching @wrap (unwrap _)@ around the expression.
    stripPair :: RewriteH CoreExpr
    stripPair = withPatFailMsg (wrongExprForm "App wrap (App unwrap x)") $ do
        App w (App u inner) <- idR
        guardMsg (exprEqual wrap w) "given wrapper does not match wrapper in expression."
        guardMsg (exprEqual unwrap u) "given unwrapper does not match unwrapper in expression."
        return inner

    -- Right to left: surround an expression of the expected type with @wrap (unwrap _)@.
    insertPair :: Type -> RewriteH CoreExpr
    insertPair tyA = do
        e <- idR
        guardMsg (exprType e `eqType` tyA) "type of expression does not match types of wrap/unwrap."
        return (App wrap (App unwrap e))
-- | 'wwAssA' with @wrap@ and @unwrap@ supplied as unparsed core strings;
-- parsing is delegated to 'parse2beforeBiR'.
wwA :: CoreString -> CoreString -> BiRewriteH CoreExpr
wwA wrapS unwrapS = parse2beforeBiR wwAssA wrapS unwrapS
-- | Worker/wrapper assumption B for already-parsed @wrap@, @unwrap@ and @f@:
--
-- >   wrap (unwrap (f x))  <==>  f x
--
-- Each direction first checks that the function being applied is @f@ and
-- then defers to the corresponding direction of 'wwAssA'.
wwAssB :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssB wrap unwrap f = bidirectional stripPair insertPair
  where
    assumptionA :: BiRewriteH CoreExpr
    assumptionA = wwAssA wrap unwrap

    -- Left to right: only the innermost function is inspected here; the
    -- wrap/unwrap heads are checked by the forward direction of assumption A.
    stripPair :: RewriteH CoreExpr
    stripPair = withPatFailMsg (wrongExprForm "App wrap (App unwrap (App f x))") $ do
        App _ (App _ (App g _)) <- idR
        guardMsg (exprEqual f g) "given body function does not match expression."
        forewardT assumptionA

    -- Right to left: the expression must be an application of @f@.
    insertPair :: RewriteH CoreExpr
    insertPair = withPatFailMsg (wrongExprForm "App f x") $ do
        App g _ <- idR
        guardMsg (exprEqual f g) "given body function does not match expression."
        backwardT assumptionA
-- | 'wwAssB' with @wrap@, @unwrap@ and @f@ supplied as unparsed core
-- strings; parsing is delegated to 'parse3beforeBiR'.
wwB :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwB wrapS unwrapS fS = parse3beforeBiR wwAssB wrapS unwrapS fS
-- | Worker/wrapper assumption C for already-parsed @wrap@, @unwrap@ and @f@:
--
-- >   fix t (\ x -> wrap (unwrap (f x)))  <==>  fix t f
--
-- 'isFixExpr' is run first and its result discarded; both directions then
-- work on the argument of the outermost application using 'wwAssB'.
wwAssC :: CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH CoreExpr
wwAssC wrap unwrap f = beforeBiR isFixExpr (const (bidirectional stripPair insertPair))
  where
    assumptionB :: BiRewriteH CoreExpr
    assumptionB = wwAssB wrap unwrap f

    -- Left to right: apply assumption B forwards under the lambda, then eta-reduce.
    stripPair :: RewriteH CoreExpr
    stripPair = appAllR idR (lamAllR idR (forewardT assumptionB) >>> etaReduce)

    -- Right to left: eta-expand with "x", then apply assumption B backwards under the lambda.
    insertPair :: RewriteH CoreExpr
    insertPair = appAllR idR (etaExpand "x" >>> lamAllR idR (backwardT assumptionB))
-- | 'wwAssC' with @wrap@, @unwrap@ and @f@ supplied as unparsed core
-- strings; parsing is delegated to 'parse3beforeBiR'.
wwC :: CoreString -> CoreString -> CoreString -> BiRewriteH CoreExpr
wwC wrapS unwrapS fS = parse3beforeBiR wwAssC wrapS unwrapS fS
-- | Succeed only on an expression of the form @fix t f@, returning the type
-- argument @t@ and the function @f@.
--
-- The head of the application must be the identifier returned by
-- 'findFixId'; otherwise the translation fails with a message naming the
-- identifier found and 'fixLocation'.
isFixExpr :: TranslateH CoreExpr (Type,CoreExpr)
isFixExpr = withPatFailMsg (wrongExprForm "fix t f") $ do
    App (App (Var candidate) (Type ty)) f <- idR
    expected <- findFixId
    guardMsg (candidate == expected)
             (var2String candidate ++ " does not match " ++ fixLocation)
    return (ty, f)
-- | Check the types of a wrap/unwrap pair by calling
-- @funsWithInverseTypes unwrap wrap@ and return its result, replacing any
-- failure message with a fixed one about the pair having the wrong types.
wrapUnwrapTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
wrapUnwrapTypes wrap unwrap = setFailMsg badPairMsg (funsWithInverseTypes unwrap wrap)
  where
    badPairMsg :: String
    badPairMsg = "given expressions have the wrong types to form a valid wrap/unwrap pair."
-- | Build the expression @fix t f@ for the given @f@.
--
-- The type argument @t@ is whatever 'endoFunType' returns for @f@, and the
-- head of the application is the identifier returned by 'findFixId'.
mkFix :: CoreExpr -> TranslateH z CoreExpr
mkFix f = do
    ty     <- endoFunType f
    fixVar <- findFixId
    let typeArg = Type ty
    return (mkCoreApps (varToCoreExpr fixVar) [typeArg, f])
-- | Qualified name used to look up the fixed-point combinator; also quoted
-- in the failure message of 'isFixExpr'.
fixLocation :: String
fixLocation = "Data.Function.fix"

-- | Look up the identifier named by 'fixLocation' using 'findIdT'.
findFixId :: TranslateH a Id
findFixId = findIdT fixName
  where
    fixName = TH.mkName fixLocation
-- | Lift a bidirectional rewrite that takes two parsed expressions to one
-- that takes two core strings.  The strings are parsed in argument order
-- before the rewrite is built.
parse2beforeBiR :: (CoreExpr -> CoreExpr -> BiRewriteH a) -> CoreString -> CoreString -> BiRewriteH a
parse2beforeBiR f s1 s2 = beforeBiR parseArgs (\ (e1, e2) -> f e1 e2)
  where
    parseArgs = do
        e1 <- parseCoreExprT s1
        e2 <- parseCoreExprT s2
        return (e1, e2)
-- | Lift a bidirectional rewrite that takes three parsed expressions to one
-- that takes three core strings.  The strings are parsed in argument order
-- before the rewrite is built.
parse3beforeBiR :: (CoreExpr -> CoreExpr -> CoreExpr -> BiRewriteH a) -> CoreString -> CoreString -> CoreString -> BiRewriteH a
parse3beforeBiR f s1 s2 s3 = beforeBiR parseArgs (\ (e1, e2, e3) -> f e1 e2 e3)
  where
    parseArgs = do
        e1 <- parseCoreExprT s1
        e2 <- parseCoreExprT s2
        e3 <- parseCoreExprT s3
        return (e1, e2, e3)