-- | This module implements the translation from the multi-tick
-- calculus to the single-tick calculus.

{-# LANGUAGE CPP #-}

module Rattus.Plugin.SingleTick
  (toSingleTick) where

#if __GLASGOW_HASKELL__ >= 900
import GHC.Plugins
#else
import GhcPlugins
#endif

  
import Rattus.Plugin.Utils
import Prelude hiding ((<>))
import Control.Monad.Trans.Writer.Strict
import Control.Monad.Trans.Class
import Data.List

-- | Transform the given expression from the multi-tick calculus into
-- the single-tick calculus form.
--
-- The interesting cases are applications of @delay@ and lambda
-- abstractions:
--
-- * For an application of @delay@, every @adv t@ in the delayed term
--   whose argument @t@ is not already a variable is rewritten to
--   @adv x@ for a fresh @x@, and @let x = t@ is placed around the
--   application of @delay@ (see 'extractAdv').
--
-- * For a lambda, every @adv t@ in its body is replaced by a fresh
--   variable @x@, and @let x = adv t@ is placed around the lambda
--   (see 'extractAdv'').
--
-- The extracted arguments are themselves translated recursively. All
-- other constructs are traversed structurally.
toSingleTick :: CoreExpr -> CoreM CoreExpr
toSingleTick (Let (Rec bs) e) = do
  e' <- toSingleTick e
  bs' <- mapM (mapM toSingleTick) bs
  return (Let (Rec bs') e')
toSingleTick (Let (NonRec b e1) e2) = do
  e1' <- toSingleTick e1
  e2' <- toSingleTick e2
  return (Let (NonRec b e1') e2')
toSingleTick (Case e b ty alts) = do
  e' <- toSingleTick e
  alts' <- mapM (\(c, bs, rhs) -> (\rhs' -> (c, bs, rhs')) <$> toSingleTick rhs)
                alts
  return (Case e' b ty alts')
toSingleTick (Cast e c) = (`Cast` c) <$> toSingleTick e
toSingleTick (Tick t e) = Tick t <$> toSingleTick e
toSingleTick (Lam x e) = do
  -- Hoist every @adv t@ out of the body. Each hoisted argument @t@ is
  -- translated as well before it is re-bound outside the lambda.
  (e', advs) <- runWriterT (extractAdv' e)
  advs' <- mapM (\(v, advE, arg) -> (\arg' -> (v, advE, arg')) <$> toSingleTick arg)
                advs
  return (foldLets' advs' (Lam x e'))
toSingleTick (App e1 e2)
  | isDelayApp e1 = do
      -- Make every non-variable adv argument under this delay a fresh
      -- variable bound outside of it, then translate those arguments.
      (e2', advs) <- runWriterT (extractAdv e2)
      advs' <- mapM (mapM toSingleTick) advs
      return (foldLets advs' (App e1 e2'))
  | otherwise = do
      e1' <- toSingleTick e1
      e2' <- toSingleTick e2
      return (App e1' e2')
toSingleTick e@Type{} = return e
toSingleTick e@Var{} = return e
toSingleTick e@Lit{} = return e
toSingleTick e@Coercion{} = return e

-- | Wrap an expression in one non-recursive @let@ per binding.
--
-- Because this is a left fold, the last binding of the list ends up
-- outermost:
--
-- > foldLets [(x1,b1),(x2,b2)] e == let x2 = b2 in let x1 = b1 in e
foldLets :: [(Id,CoreExpr)] -> CoreExpr -> CoreExpr
foldLets ls e = foldl' (\acc (x, b) -> Let (NonRec x b) acc) e ls

-- | Wrap an expression in one non-recursive @let@ per triple
-- @(x, f, t)@, binding @x@ to the application @f t@.
--
-- As with 'foldLets', the last triple of the list ends up outermost:
--
-- > foldLets' [(x1,f1,t1),(x2,f2,t2)] e
-- >   == let x2 = f2 t2 in let x1 = f1 t1 in e
foldLets' :: [(Id,CoreExpr,CoreExpr)] -> CoreExpr -> CoreExpr
foldLets' ls = foldLets [ (x, App f t) | (x, f, t) <- ls ]

-- | Is the expression a variable?
--
-- Casts and ticks are looked through. Applications to type arguments,
-- or to arguments whose type does not have lifted kind, are also
-- looked through. Any other application makes the result 'False'.
isVar :: CoreExpr -> Bool
isVar (App e e')
  -- 'isType' is tested first so that 'exprType' is only computed for
  -- value arguments.
  | isType e' || not (tcIsLiftedTypeKind (typeKind (exprType e'))) = isVar e
  | otherwise = False
isVar (Cast e _) = isVar e
isVar (Tick _ e) = isVar e
isVar (Var _) = True
isVar _ = False


-- | Rebuild an application of @adv@ so that its argument is a variable.
-- The application is given as the @adv@ expression @e1@ and its
-- argument @e2@.
--
-- If @e2@ already is a variable (see 'isVar'), @App e1 e2@ is returned
-- unchanged and nothing is recorded. Otherwise:
--
-- * a fresh variable @x@ is created;
-- * the pair @(x, e2)@ is recorded in the writer;
-- * @App e1 (Var x)@ is returned.
--
-- The caller is responsible for binding @x@ to @e2@.
extractAdvApp :: CoreExpr -> CoreExpr -> WriterT [(Id,CoreExpr)] CoreM CoreExpr
extractAdvApp e1 e2
  | isVar e2 = return (App e1 e2)
  | otherwise = do
      x <- lift (mkSysLocalFromExpr (fsLit "adv") e2)
      tell [(x, e2)]
      return (App e1 (Var x))

-- | Pull @adv@ out of delayed terms. The writer monad returns mappings
-- from fresh variables to terms that occur as the argument of @adv@.
--
-- That is, occurrences of @adv t@ are replaced with @adv x@ (for some
-- fresh variable @x@), and the pair @(x,t)@ is returned in the writer
-- monad. If @t@ already is a variable (see 'isVar'), the occurrence is
-- kept as it is and nothing is recorded.
--
-- Nested constructs are handled as follows:
--
-- * A nested @delay@ is processed with its own writer. Its pairs are
--   bound by @let@s around it, after their terms have been processed by
--   'extractAdv' in the outer writer.
--
-- * A lambda first has its @adv@ applications hoisted out by
--   'extractAdv''. The hoisted applications are then treated like any
--   other @adv t@.
--
-- * Applications of @box@ and recursive @let@s are translated with
--   'toSingleTick' and record nothing.
extractAdv :: CoreExpr -> WriterT [(Id,CoreExpr)] CoreM CoreExpr
extractAdv e@(App e1 e2)
  | isAdvApp e1 = extractAdvApp e1 e2
  | isDelayApp e1 = do
      (e2', advs) <- lift $ runWriterT (extractAdv e2)
      advs' <- mapM (mapM extractAdv) advs
      return (foldLets advs' (App e1 e2'))
  | isBoxApp e1 = lift $ toSingleTick e
  | otherwise = do
      e1' <- extractAdv e1
      e2' <- extractAdv e2
      return (App e1' e2')
extractAdv (Lam x e) = do
  (e', advs) <- lift $ runWriterT (extractAdv' e)
  advs' <- mapM (\(v, advE, arg) -> (,) v <$> extractAdvApp advE arg) advs
  return (foldLets advs' (Lam x e'))
extractAdv (Case e b ty alts) = do
  e' <- extractAdv e
  alts' <- mapM (\(c, bs, rhs) -> (\rhs' -> (c, bs, rhs')) <$> extractAdv rhs)
                alts
  return (Case e' b ty alts')
extractAdv (Cast e c) = (`Cast` c) <$> extractAdv e
extractAdv (Tick t e) = Tick t <$> extractAdv e
extractAdv e@(Let Rec{} _) = lift $ toSingleTick e
extractAdv (Let (NonRec b e1) e2) = do
  e1' <- extractAdv e1
  e2' <- extractAdv e2
  return (Let (NonRec b e1') e2')
extractAdv e@Type{} = return e
extractAdv e@Var{} = return e
extractAdv e@Lit{} = return e
extractAdv e@Coercion{} = return e

-- | Pull @adv@ out of lambdas. The writer monad returns mappings from
-- fresh variables to occurrences of @adv@ and the term it is applied to.
--
-- That is, occurrences of @adv t@ are replaced with a fresh variable
-- @x@, and the triple @(x,adv,t)@ is returned in the writer monad.
--
-- Nested constructs are handled as follows:
--
-- * Nested lambdas are traversed, so their @adv@ applications are
--   hoisted as well.
--
-- * A nested @delay@ is processed with 'extractAdv' in its own writer.
--   Its pairs are bound by @let@s around it, after their terms have been
--   processed by 'extractAdv''.
--
-- * Applications of @box@ and recursive @let@s are translated with
--   'toSingleTick' and record nothing.
--
-- NOTE(review): hoisting @adv t@ outward is only sound if @t@ does not
-- mention variables bound inside the traversed term (lambda, let or
-- case binders). This is presumably guaranteed by Rattus's scope
-- checking; confirm.
extractAdv' :: CoreExpr -> WriterT [(Id,CoreExpr,CoreExpr)] CoreM CoreExpr
extractAdv' e@(App e1 e2)
  | isAdvApp e1 = do
      -- @x@ replaces the whole application @adv t@, so it is created
      -- from @e@ rather than from the argument @e2@.
      x <- lift (mkSysLocalFromExpr (fsLit "adv") e)
      tell [(x, e1, e2)]
      return (Var x)
  | isDelayApp e1 = do
      (e2', advs) <- lift $ runWriterT (extractAdv e2)
      advs' <- mapM (mapM extractAdv') advs
      return (foldLets advs' (App e1 e2'))
  | isBoxApp e1 = lift $ toSingleTick e
  | otherwise = do
      e1' <- extractAdv' e1
      e2' <- extractAdv' e2
      return (App e1' e2')
extractAdv' (Lam x e) = Lam x <$> extractAdv' e
extractAdv' (Case e b ty alts) = do
  e' <- extractAdv' e
  alts' <- mapM (\(c, bs, rhs) -> (\rhs' -> (c, bs, rhs')) <$> extractAdv' rhs)
                alts
  return (Case e' b ty alts')
extractAdv' (Cast e c) = (`Cast` c) <$> extractAdv' e
extractAdv' (Tick t e) = Tick t <$> extractAdv' e
extractAdv' e@(Let Rec{} _) = lift $ toSingleTick e
extractAdv' (Let (NonRec b e1) e2) = do
  e1' <- extractAdv' e1
  e2' <- extractAdv' e2
  return (Let (NonRec b e1') e2')
extractAdv' e@Type{} = return e
extractAdv' e@Var{} = return e
extractAdv' e@Lit{} = return e
extractAdv' e@Coercion{} = return e



-- | Is the expression the Rattus primitive @delay@ (possibly applied to
-- type arguments)? See 'isPrimApp'.
isDelayApp :: CoreExpr -> Bool
isDelayApp = isPrimApp (== "delay")

-- | Is the expression the Rattus primitive @Box@ or @box@ (possibly
-- applied to type arguments)? See 'isPrimApp'.
isBoxApp :: CoreExpr -> Bool
isBoxApp = isPrimApp (\occ -> occ == "Box" || occ == "box")

-- | Is the expression the Rattus primitive @adv@ (possibly applied to
-- type arguments)? See 'isPrimApp'.
isAdvApp :: CoreExpr -> Bool
isAdvApp = isPrimApp (== "adv")


-- | Does the expression consist of a Rattus primitive variable whose
-- occurrence name satisfies @p@ (see 'isPrimVar')?
--
-- A bare variable matches as well. Casts and ticks are looked through,
-- as are applications to type arguments or to arguments whose type
-- does not have lifted kind. Any other application makes the result
-- 'False'.
isPrimApp :: (String -> Bool) -> CoreExpr -> Bool
isPrimApp p (App e e')
  -- 'isType' is tested first so that 'exprType' is only computed for
  -- value arguments.
  | isType e' || not (tcIsLiftedTypeKind (typeKind (exprType e'))) = isPrimApp p e
  | otherwise = False
isPrimApp p (Cast e _) = isPrimApp p e
isPrimApp p (Tick _ e) = isPrimApp p e
isPrimApp p (Var v) = isPrimVar p v
isPrimApp _ _ = False

-- | Is the variable a Rattus primitive whose occurrence name satisfies
-- @p@?
--
-- The variable's name must belong to a module, and that module must be
-- @Rattus.Internal@ or @Rattus.Primitives@. Names without a module
-- (e.g. local binders) never match.
isPrimVar :: (String -> Bool) -> Var -> Bool
isPrimVar p v = case nameModule_maybe name of
    Nothing -> False
    Just modl ->
      p (getOccString name)
        && moduleNameString (moduleName modl) `elem` ["Rattus.Internal", "Rattus.Primitives"]
  where
    name = varName v