-- | Facilities for changing the representation of some fragment,
-- within a monadic context.  We call this "rephrasing", for no deep
-- reason.
module Futhark.IR.Rephrase
  ( rephraseProg,
    rephraseFunDef,
    rephraseExp,
    rephraseBody,
    rephraseStm,
    rephraseLambda,
    rephrasePat,
    rephrasePatElem,
    Rephraser (..),
    RephraseOp (..),
  )
where

import Futhark.IR.Syntax
import Futhark.IR.Traversals

-- | A collection of functions that together allow us to rephrase some
-- IR fragment, in some monad @m@.  If we let @m@ be the 'Maybe'
-- monad, we can conveniently do rephrasing that might fail.  This is
-- useful if you want to see if some IR in e.g. the @Kernels@ rep
-- actually uses any @Kernels@-specific operations.
--
-- Each field converts one representation-specific component from the
-- @from@ representation to the @to@ representation.
data Rephraser m from to = Rephraser
  { -- | Rephrase the 'ExpDec' carried in the auxiliary information of
    -- a statement.
    rephraseExpDec :: ExpDec from -> m (ExpDec to),
    -- | Rephrase the 'LetDec' carried by the elements of a statement
    -- pattern.
    rephraseLetBoundDec :: LetDec from -> m (LetDec to),
    -- | Rephrase the 'FParamInfo' of a function parameter.
    rephraseFParamDec :: FParamInfo from -> m (FParamInfo to),
    -- | Rephrase the 'LParamInfo' of a lambda parameter.
    rephraseLParamDec :: LParamInfo from -> m (LParamInfo to),
    -- | Rephrase the 'BodyDec' of a body.
    rephraseBodyDec :: BodyDec from -> m (BodyDec to),
    -- | Rephrase a 'RetType'.
    rephraseRetType :: RetType from -> m (RetType to),
    -- | Rephrase a 'BranchType'.
    rephraseBranchType :: BranchType from -> m (BranchType to),
    -- | Rephrase a representation-specific 'Op'.
    rephraseOp :: Op from -> m (Op to)
  }

-- | Rephrase an entire program.
--
-- The constants of the program are rephrased first, followed by its
-- function definitions; the effects in @m@ happen in that order.  All
-- other parts of the program are carried over unchanged.
rephraseProg :: Monad m => Rephraser m from to -> Prog from -> m (Prog to)
rephraseProg rephraser prog = do
  consts <- mapM (rephraseStm rephraser) (progConsts prog)
  funs <- mapM (rephraseFunDef rephraser) (progFuns prog)
  -- Both representation-dependent fields are replaced at once, which
  -- is what lets the record update change the type of the program.
  pure $ prog {progConsts = consts, progFuns = funs}

-- | Rephrase a function definition.
--
-- The body is rephrased first, then the parameters (using
-- 'rephraseFParamDec'), and finally the return types (using
-- 'rephraseRetType').  The remaining fields of the definition are
-- carried over unchanged.
rephraseFunDef :: Monad m => Rephraser m from to -> FunDef from -> m (FunDef to)
rephraseFunDef rephraser fundec = do
  body' <- rephraseBody rephraser $ funDefBody fundec
  params' <- mapM (rephraseParam $ rephraseFParamDec rephraser) $ funDefParams fundec
  rettype' <- mapM (rephraseRetType rephraser) $ funDefRetType fundec
  pure
    fundec
      { funDefBody = body',
        funDefParams = params',
        funDefRetType = rettype'
      }

-- | Rephrase an expression.
--
-- This builds a 'Mapper' from the 'Rephraser' and hands it to
-- 'mapExpM', which performs the actual traversal of the expression.
rephraseExp :: Monad m => Rephraser m from to -> Exp from -> m (Exp to)
rephraseExp = mapExpM . mapper

-- | Rephrase a statement.
--
-- The pattern is rephrased with 'rephraseLetBoundDec', the expression
-- decoration in the statement's auxiliary information with
-- 'rephraseExpDec', and the expression itself with 'rephraseExp', in
-- that order.  The certificates and attributes are kept as they are.
rephraseStm :: Monad m => Rephraser m from to -> Stm from -> m (Stm to)
rephraseStm rephraser (Let pat (StmAux cs attrs dec) e) =
  Let
    <$> rephrasePat (rephraseLetBoundDec rephraser) pat
    <*> (StmAux cs attrs <$> rephraseExpDec rephraser dec)
    <*> rephraseExp rephraser e

-- | Rephrase a pattern by applying the given function to every
-- decoration it contains.  This is simply 'traverse' at the 'Pat'
-- type.
rephrasePat ::
  Monad m =>
  (from -> m to) ->
  Pat from ->
  m (Pat to)
rephrasePat = traverse

-- | Rephrase a pattern element.  The name of the element is kept;
-- only its decoration is passed through the given function.
rephrasePatElem :: Monad m => (from -> m to) -> PatElem from -> m (PatElem to)
rephrasePatElem rephraser (PatElem ident dec) =
  PatElem ident <$> rephraser dec

-- | Rephrase a parameter.  The attributes and name of the parameter
-- are kept; only its decoration is passed through the given function.
rephraseParam :: Monad m => (from -> m to) -> Param from -> m (Param to)
rephraseParam rephraser (Param attrs name dec) =
  Param attrs name <$> rephraser dec

-- | Rephrase a body.
--
-- The body decoration is rephrased with 'rephraseBodyDec', and then
-- each statement is rephrased in order with 'rephraseStm'.  The
-- result of the body is kept as it is.
rephraseBody :: Monad m => Rephraser m from to -> Body from -> m (Body to)
rephraseBody rephraser (Body dec stms res) =
  Body
    <$> rephraseBodyDec rephraser dec
    -- The statement sequence is traversed directly, just as
    -- 'rephraseProg' does for the program constants, rather than
    -- going through an intermediate list.
    <*> mapM (rephraseStm rephraser) stms
    <*> pure res

-- | Rephrase a lambda.
--
-- The body is rephrased first, then the parameters (using
-- 'rephraseLParamDec').  The remaining fields of the lambda are
-- carried over unchanged.
rephraseLambda :: Monad m => Rephraser m from to -> Lambda from -> m (Lambda to)
rephraseLambda rephraser lam = do
  body' <- rephraseBody rephraser $ lambdaBody lam
  params' <- mapM (rephraseParam $ rephraseLParamDec rephraser) $ lambdaParams lam
  pure lam {lambdaBody = body', lambdaParams = params'}

-- | Construct the 'Mapper' that 'rephraseExp' passes to 'mapExpM'.
--
-- It starts from 'identityMapper' and overrides the handlers for
-- bodies, return types, branch types, function and lambda parameters,
-- and ops, so that each one goes through the corresponding part of
-- the 'Rephraser'.
mapper :: Monad m => Rephraser m from to -> Mapper from to m
mapper rephraser =
  identityMapper
    { -- The first argument passed to the body handler is ignored.
      mapOnBody = const $ rephraseBody rephraser,
      mapOnRetType = rephraseRetType rephraser,
      mapOnBranchType = rephraseBranchType rephraser,
      mapOnFParam = rephraseParam (rephraseFParamDec rephraser),
      mapOnLParam = rephraseParam (rephraseLParamDec rephraser),
      mapOnOp = rephraseOp rephraser
    }

-- | Operation types, parameterised by a representation, whose
-- contained fragments can be rephrased from one representation to
-- another.
class RephraseOp op where
  -- | Rephrase an @op from@ into an @op to@ using the given
  -- 'Rephraser', with any effects taking place in the monad @m@.
  rephraseInOp :: Monad m => Rephraser m from to -> op from -> m (op to)

-- | 'NoOp' contains nothing that depends on the representation, so
-- rephrasing it ignores the 'Rephraser' and always succeeds.
instance RephraseOp NoOp where
  rephraseInOp _ NoOp = pure NoOp